Compare commits

...

94 Commits

Author SHA1 Message Date
07c6b89525 docs: remove project structure, boring and hard to maintain 2025-12-20 17:35:16 +01:00
817205d42f refactor: modernize using min() 2025-12-16 15:45:51 +01:00
199ac143a4 refactor: replace interface{} by any 2025-12-16 15:05:23 +01:00
aa7e259ed0 format: shfmt 2025-12-16 15:02:42 +01:00
4587609e17 refactor: create createTestRouter and test edge cases 2025-12-14 21:14:42 +01:00
33da6503e3 test: also test put/delete routes 2025-12-14 21:06:15 +01:00
cafc44ed77 test: add a test for route parameters 2025-12-14 21:04:36 +01:00
1480135e75 test: verified all routes to exist 2025-12-14 21:02:25 +01:00
02a764c736 clean: remove merged files 2025-12-14 20:52:14 +01:00
6834ad7764 refactor: merge facade, types and utils into one auth_service.go 2025-12-14 20:52:03 +01:00
dcf054046f test: fix parallel processor test expectations and setup 2025-12-10 07:30:27 +01:00
d2a788933d fix: track completed items in the main loop instead of using the index 2025-12-10 07:29:55 +01:00
18be3950dc clean: obsolete function 2025-12-09 22:07:30 +01:00
f9cb140e95 clean: removed the obsolete functions outputMessage and outputError 2025-12-09 22:06:12 +01:00
86d4835ccf feat: seed user is now uniq 2025-12-09 22:03:26 +01:00
feddb2ed43 test: new unit test for EnsureSeedUser 2025-12-09 22:03:16 +01:00
457b5c88e2 refactor: improve seed consistency validation 2025-12-09 21:53:03 +01:00
a8d363b2bf fix: templates now parse with the same func map as the page handler 2025-12-09 21:37:21 +01:00
0cd68e847c refactor: add a helper to centralize CSRF token retrieval 2025-12-09 15:58:28 +01:00
df6aeed713 docs: unocss it is 2025-12-04 20:43:30 +01:00
785faeb60c feat: update alpine to 3.23 2025-12-04 10:01:19 +01:00
0623c027ba docs: prepare CONTRIBUTING.md 2025-12-03 20:57:32 +01:00
d4e91b6034 refactor: complete refactor and better helpers use 2025-11-29 15:19:41 +01:00
7d46d3e81b clean: remove the unused expectedValue in assertHeader (always set to "") 2025-11-29 15:19:28 +01:00
216aaf3117 refactor: clean code and use new request helpers 2025-11-29 14:58:52 +01:00
435047ad0c refactor: clean code 2025-11-29 14:58:37 +01:00
b7ee8bd11d refactor: clean variable names and use new request helpers 2025-11-29 14:58:20 +01:00
040cd48be8 refactor: clean variables 2025-11-29 14:58:07 +01:00
2dd16e0e00 refactor: complete 2025-11-29 14:56:18 +01:00
d6db70cc79 refactor: clean code and variables, use new request helpers 2025-11-29 14:55:47 +01:00
58e10ade7d refactor: clean variable names and modernize code 2025-11-29 14:50:35 +01:00
7403a75d8e refactor: clean variable naming 2025-11-29 14:46:26 +01:00
b429bc11af refactor: clean code and use new request helpers 2025-11-29 14:41:38 +01:00
2ec5c28fb5 refactor: rename variables and clean code 2025-11-29 14:37:18 +01:00
3743a99e40 refactor: req -> request, rec -> recorder, reqBody -> requestBody... 2025-11-29 14:21:07 +01:00
5710921b87 refactor: use new request helpers 2025-11-29 14:17:25 +01:00
84d9c81484 refactor: rec -> recorder, req -> request and modernize loop 2025-11-29 14:15:07 +01:00
b0c2038927 feat: add new helpers to make requests properly in integration tests 2025-11-29 14:11:32 +01:00
fd88931146 refactor: variable names and modernize loop 2025-11-25 14:37:59 +01:00
6ce0f4dfad refactor: name variables 2025-11-25 10:13:10 +01:00
68e3dceefc refactor: name variables 2025-11-25 10:08:48 +01:00
cded14c526 fix: force correct mime types for static files after modifying compression middleware's buffering 2025-11-24 07:53:08 +01:00
cfca668ca6 docs: clean readme 2025-11-23 21:56:53 +01:00
279255b587 fix: don't let rate limiting fails the test 2025-11-23 21:48:11 +01:00
b83f8c2228 fix: update ValidationMiddleware to return a JSON error response when JSON decoding fails 2025-11-23 21:42:27 +01:00
aabc48128c fix: use router in handlers integration tests (for dto validation) 2025-11-23 15:10:55 +01:00
68716b977b fix: verify XSS sanitization in handler response instead of repository stub 2025-11-23 15:01:54 +01:00
dbe1600632 fix: indentation 2025-11-23 14:49:31 +01:00
458e25cf79 fix: modify compression middleware to pass through redirects immediately without buffering 2025-11-23 14:48:59 +01:00
d4595d8dbf fix: properly encoding the flash message in the redirect URL 2025-11-23 14:48:39 +01:00
c5418f4e4c docs: update swagger 2025-11-23 14:26:52 +01:00
db0369225e refactor: update references to VoteRequest 2025-11-23 14:26:45 +01:00
07ac965b3d refactor: use consistent naming (VoteRequest -> CastVoteRequest) 2025-11-23 14:26:19 +01:00
e2e5d42035 feat: add SetValidatedDTOInContext to support test helper functions 2025-11-23 14:22:59 +01:00
6e4b41894f fix: update test cases to use createCreatePostRequests 2025-11-23 14:22:35 +01:00
0a8ed2e27c fix: add explicite validation check for empty url, title and content length 2025-11-23 14:21:30 +01:00
216e8657f6 feat: add generic createRequestWithDTO along with helpers functions 2025-11-23 14:20:59 +01:00
fb7206c0a2 fix: test context handling 2025-11-23 14:20:09 +01:00
c25926514b fix: add explicit empty-field validation check in handlers 2025-11-23 14:19:54 +01:00
964785e494 docs: update swagger 2025-11-23 13:47:38 +01:00
9c67cd2a47 feat: update vote handler to use dto VoteRequest and update MountRoutes 2025-11-23 13:47:31 +01:00
8b5cc8e939 feat: add VoteRequest with its validation types 2025-11-23 13:46:51 +01:00
0e71b28615 feat: update CreateUser to use dto.RegisterRequest and update MountRoutes to apply validation middleware 2025-11-23 13:43:47 +01:00
cd740da57a feat: update methods to use validated DTOs and update MountRoutes 2025-11-23 13:43:14 +01:00
abe4a3dc88 feat: update handlers to use GetValidatedDTO instead of manual decoding and update MountRoutes to wrap handlers with WithValidation for all DTO-based routes 2025-11-23 13:42:52 +01:00
738243d945 feat: add ValidationMiddleware to RouteModuleConfig 2025-11-23 13:41:55 +01:00
4fbdfb6e4a feat: add two helpers function to retrieve validated DTOs from request context and to apply validation middleware 2025-11-23 13:41:07 +01:00
6bb3a78b88 feat: Add ValidationMiddleware to router configuration 2025-11-23 13:40:31 +01:00
54e37e59fc docs: update swagger 2025-11-23 13:35:00 +01:00
5d4b38ddc4 feat: add validation tags to request DTOs 2025-11-23 13:34:53 +01:00
7dc119ecde docs: update swagger 2025-11-23 13:17:14 +01:00
52c9f4a02b feat: add internal/dto to swagger directories 2025-11-23 13:16:44 +01:00
be91a135bc clean: empty line 2025-11-23 13:14:41 +01:00
2d7ff9778b feat: update swagger comments following dtos relocation 2025-11-23 13:14:07 +01:00
4ff3fd3583 refactor: remove UpdatePostRequest definition and update swagger comments 2025-11-23 13:13:53 +01:00
73121cad15 refactor: remove all request DTO, update swagger comments and update token related methods to use dto ones 2025-11-23 13:13:23 +01:00
c5bf1b2fd8 feat: locate post-related request DTOs 2025-11-23 13:12:36 +01:00
eedebe60d1 feat: locate auth-related request DTOs 2025-11-23 13:12:10 +01:00
80fb37371f update: fix go version and update alpine to 3.22 2025-11-23 10:44:28 +01:00
fea49fad8d fix: add missing method to mock 2025-11-21 17:07:26 +01:00
4b04461ebb style: minor formatting adjustments 2025-11-21 17:06:04 +01:00
533e8c3d46 feat: add GetByUsernamePrefixFn field and method to UserRepositoryStub 2025-11-21 17:05:48 +01:00
df568291f1 feat: add GetByUsernamePrefix implementation to MockUserRepository 2025-11-21 17:05:31 +01:00
81acce62b1 feat: add GetByUsernamePrefix method to interface and add implementation 2025-11-21 17:05:01 +01:00
989a61e7d5 feat: use getByUsernamePrefix to optimize findExistingSeedUser() 2025-11-21 17:04:35 +01:00
3ffd83b0fb feat: ignore docs in make format 2025-11-21 17:02:06 +01:00
62d466e4fa refactor: use go generics 2025-11-21 16:56:26 +01:00
0cd428d5d9 feat: use connection pooling instead of a single connection 2025-11-21 16:53:46 +01:00
5c239ad61d feat: add missing GetVoteCountsByPostID method to the errorVoteRepository test mock 2025-11-21 16:50:23 +01:00
01f2b1fe75 feat: remove loop and use GetVoteCountsByPostID 2025-11-21 16:48:48 +01:00
28134c101c feat: add GetVoteCountsByPostID to the mock for testing 2025-11-21 16:48:15 +01:00
2f78370d43 feat: GetVoteCountsByPostID: use a single sql query to returns up votes and down votes counts 2025-11-21 16:47:52 +01:00
39598a166d feat: remove redundat getbyemail call to reduce db query by 2 (1Q/user creation instead of 2) 2025-11-21 16:43:46 +01:00
fa9474d863 revert: db transaction use, avoiding the pgx RETURNING issue while maintaining data consistency 2025-11-21 16:31:06 +01:00
65 changed files with 3915 additions and 3509 deletions

1
.prettierignore Normal file
View File

@@ -0,0 +1 @@
docs/

View File

@@ -1,4 +1,4 @@
ARG GO_VERSION=1.25.3
ARG GO_VERSION=1.25.4
# Building the binary using a golang alpine image
FROM golang:${GO_VERSION}-alpine AS go-builder
@@ -11,7 +11,7 @@ ARG TARGETARCH=amd64
RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags="-s -w" -o /out/goyco ./cmd/goyco
# building the application image
FROM alpine:3.21
FROM alpine:3.23
RUN addgroup -S goyco && adduser -S -G goyco goyco \
&& apk add --no-cache ca-certificates tzdata
WORKDIR /app

View File

@@ -44,8 +44,8 @@ clean:
rm -fr dist/*
format:
$(PRETTIER) -w .
$(GO) fmt ./...
$(PRETTIER) -w . --ignore-path .prettierignore
$(GO) fmt $(shell $(GO) list ./... | grep -v 'docs')
lint:
$(GOLANGCI_LINT) run

View File

@@ -354,44 +354,6 @@ It will start the application in development mode. You can also run it as a daem
Then, use `./bin/goyco` to manage the application and notably to seed the database with sample data.
### Project Structure
```bash
goyco/
├── bin/ # Compiled binaries (created after build)
├── cmd/
│ └── goyco/ # Main CLI application entrypoint
├── docker/ # Docker Compose & related files
├── docs/ # Documentation and API specs
├── internal/
│ ├── config/ # Configuration management
│ ├── database/ # Database models and access
│ ├── dto/ # Data Transfer Objects (DTOs)
│ ├── e2e/ # End-to-end tests
│ ├── fuzz/ # Fuzz tests
│ ├── handlers/ # HTTP handlers
│ ├── integration/ # Integration tests
│ ├── middleware/ # HTTP middleware
│ ├── repositories/ # Data access layer
│ ├── security/ # Security and auth logic
│ ├── server/ # HTTP server implementation
│ ├── services/ # Business logic
│ ├── static/ # Static web assets
│ ├── templates/ # HTML templates
│ ├── testutils/ # Test helpers/utilities
│ ├── validation/ # Input validation
│ └── version/ # Version information
├── scripts/ # Utility/maintenance scripts
├── services/
│ └── goyco.service # Systemd service unit example
├── .env.example # Environment variable example
├── AUTHORS # Authors file
├── Dockerfile # Docker build file
├── LICENSE # License file
├── Makefile # Project build/test targets
└── README.md # This file
```
### Testing
```bash
@@ -437,31 +399,10 @@ This will regenerate the swagger documentation and update the `docs/swagger.json
- [ ] add right management within the app
- [ ] add an admin backoffice to manage rights, users, content and settings
- [ ] add a way to run read-only communities
- [ ] maybe use a css framework instead of raw css
- [ ] migrate raw CSS to UnoCSS
- [ ] kubernetes deployment
- [ ] store configuration in the database
## Contributing
Feedbacks are welcome!
But as it's a personal gitea and you cannot create accounts, feel free to contact me at <sandro@cazzaniga.fr> to get one.
Once you have it, follow the usual workflow:
1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests for new functionality
5. Ensure all tests pass
6. Submit a pull request
Then, I'll review your changes and merge them if they are good.
## License
This project is licensed under the GNU General Public License v3.0 or later (GPLv3+). See the [LICENSE](LICENSE) file for details.
---
**Goyco** - A modern news aggregation platform built with Go, PostgreSQL and most importantly, love.

View File

@@ -8,9 +8,10 @@ import (
"os"
"sync"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"gorm.io/gorm"
)
var ErrHelpRequested = errors.New("help requested")
@@ -40,11 +41,11 @@ var (
)
func defaultDBConnector(cfg *config.Config) (*gorm.DB, func() error, error) {
db, err := database.Connect(cfg)
poolManager, err := database.ConnectWithPool(cfg)
if err != nil {
return nil, nil, err
}
return db, func() error { return database.Close(db) }, nil
return poolManager.GetDB(), func() error { return poolManager.Close() }, nil
}
func SetDBConnector(connector DBConnector) {
@@ -118,26 +119,6 @@ func outputJSON(v interface{}) error {
return encoder.Encode(v)
}
func outputMessage(message string, args ...interface{}) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"message": fmt.Sprintf(message, args...),
})
} else {
fmt.Printf(message+"\n", args...)
}
}
func outputError(message string, args ...interface{}) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"error": fmt.Sprintf(message, args...),
})
} else {
fmt.Fprintf(os.Stderr, message+"\n", args...)
}
}
func outputWarning(message string, args ...interface{}) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{

View File

@@ -5,9 +5,10 @@ import (
"fmt"
"os"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"gorm.io/gorm"
)
func HandleMigrateCommand(cfg *config.Config, name string, args []string) error {
@@ -37,7 +38,7 @@ func runMigrateCommand(db *gorm.DB) error {
return fmt.Errorf("run migrations: %w", err)
}
if IsJSONOutput() {
outputJSON(map[string]interface{}{
outputJSON(map[string]any{
"action": "migrations_applied",
"status": "success",
})

View File

@@ -42,14 +42,23 @@ func (p *ParallelProcessor) SetPasswordHash(hash string) {
p.passwordHash = hash
}
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
type indexedResult[T any] struct {
value T
index int
}
results := make(chan userResult, count)
func processInParallel[T any](
ctx context.Context,
maxWorkers int,
count int,
processor func(index int) (T, error),
errorPrefix string,
progress *ProgressIndicator,
) ([]T, error) {
results := make(chan indexedResult[T], count)
errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i := range count {
@@ -65,13 +74,13 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
}
defer func() { <-semaphore }()
user, err := p.createSingleUser(userRepo, index+1)
value, err := processor(index + 1)
if err != nil {
errors <- fmt.Errorf("create user %d: %w", index+1, err)
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return
}
results <- userResult{user: user, index: index}
results <- indexedResult[T]{value: value, index: index}
}(i)
}
@@ -81,7 +90,7 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
close(errors)
}()
users := make([]database.User, count)
items := make([]T, count)
completed := 0
firstError := make(chan error, 1)
@@ -101,9 +110,9 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
select {
case result, ok := <-results:
if !ok {
return users, nil
return items, nil
}
users[result.index] = result.user
items[result.index] = result.value
completed++
if progress != nil {
progress.Update(completed)
@@ -111,26 +120,59 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
case err := <-firstError:
return nil, err
case <-ctx.Done():
return nil, fmt.Errorf("timeout creating users: %w", ctx.Err())
return nil, fmt.Errorf("timeout: %w", ctx.Err())
}
}
return users, nil
return items, nil
}
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.User, error) {
return p.createSingleUser(userRepo, index)
},
"create user",
progress,
)
}
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan postResult, count)
return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.Post, error) {
return p.createSinglePost(postRepo, authorID, index)
},
"create post",
progress,
)
}
func processItemsInParallel[T any, R any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) (R, error),
errorPrefix string,
aggregator func(accumulator R, value R) R,
initialValue R,
progress *ProgressIndicator,
) (R, error) {
count := len(items)
results := make(chan indexedResult[R], count)
errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i := range count {
for i, item := range items {
wg.Add(1)
go func(index int) {
go func(index int, item T) {
defer wg.Done()
select {
@@ -141,14 +183,14 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
}
defer func() { <-semaphore }()
post, err := p.createSinglePost(postRepo, authorID, index+1)
value, err := processor(index, item)
if err != nil {
errors <- fmt.Errorf("create post %d: %w", index+1, err)
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return
}
results <- postResult{post: post, index: index}
}(i)
results <- indexedResult[R]{value: value, index: index}
}(i, item)
}
go func() {
@@ -157,7 +199,7 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
close(errors)
}()
posts := make([]database.Post, count)
accumulator := initialValue
completed := 0
firstError := make(chan error, 1)
@@ -177,36 +219,56 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
select {
case result, ok := <-results:
if !ok {
return posts, nil
return accumulator, nil
}
posts[result.index] = result.post
accumulator = aggregator(accumulator, result.value)
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return nil, err
return initialValue, err
case <-ctx.Done():
return nil, fmt.Errorf("timeout creating posts: %w", ctx.Err())
return initialValue, fmt.Errorf("timeout: %w", ctx.Err())
}
}
return posts, nil
return accumulator, nil
}
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan voteResult, len(posts))
errors := make(chan error, len(posts))
return processItemsInParallel(ctx, p.maxWorkers, posts,
func(index int, post database.Post) (int, error) {
return p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
},
"create votes for post",
func(acc, val int) int { return acc + val },
0,
progress,
)
}
semaphore := make(chan struct{}, p.maxWorkers)
func processItemsInParallelNoResult[T any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) error,
errorFormatter func(index int, item T, err error) error,
progress *ProgressIndicator,
) error {
count := len(items)
errors := make(chan error, count)
completions := make(chan struct{}, count)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
for i, item := range items {
wg.Add(1)
go func(index int, post database.Post) {
go func(index int, item T) {
defer wg.Done()
select {
@@ -217,23 +279,26 @@ func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepo
}
defer func() { <-semaphore }()
votes, err := p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
err := processor(index, item)
if err != nil {
errors <- fmt.Errorf("create votes for post %d: %w", post.ID, err)
if errorFormatter != nil {
errors <- errorFormatter(index, item, err)
} else {
errors <- fmt.Errorf("process item %d: %w", index+1, err)
}
return
}
results <- voteResult{votes: votes, index: index}
}(i, post)
completions <- struct{}{}
}(i, item)
}
go func() {
wg.Wait()
close(results)
close(errors)
close(completions)
}()
totalVotes := 0
completed := 0
firstError := make(chan error, 1)
@@ -249,88 +314,39 @@ func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepo
}
}()
for completed < len(posts) {
for completed < count {
select {
case result, ok := <-results:
case _, ok := <-completions:
if !ok {
return totalVotes, nil
return nil
}
totalVotes += result.votes
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return 0, err
case <-ctx.Done():
return 0, fmt.Errorf("timeout creating votes: %w", ctx.Err())
}
}
return totalVotes, nil
}
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
errors := make(chan error, len(posts))
semaphore := make(chan struct{}, p.maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
wg.Add(1)
go func(index int, post database.Post) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
err := p.updateSinglePostScore(postRepo, voteRepo, post)
if err != nil {
errors <- fmt.Errorf("update post %d scores: %w", post.ID, err)
return
}
if progress != nil {
progress.Update(index + 1)
}
}(i, post)
}
go func() {
wg.Wait()
close(errors)
}()
for err := range errors {
if err != nil {
return err
case <-ctx.Done():
return fmt.Errorf("timeout: %w", ctx.Err())
}
}
return nil
}
type userResult struct {
user database.User
index int
}
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
type postResult struct {
post database.Post
index int
}
type voteResult struct {
votes int
index int
return processItemsInParallelNoResult(ctx, p.maxWorkers, posts,
func(index int, post database.Post) error {
return p.updateSinglePostScore(postRepo, voteRepo, post)
},
func(index int, post database.Post, err error) error {
return fmt.Errorf("update post %d scores: %w", post.ID, err)
},
progress,
)
}
func (p *ParallelProcessor) generateRandomIdentifier() string {
@@ -352,12 +368,7 @@ func (p *ParallelProcessor) createSingleUser(userRepo repositories.UserRepositor
const maxRetries = 10
for range maxRetries {
user, err := userRepo.GetByEmail(email)
if err == nil {
return *user, nil
}
user = &database.User{
user := &database.User{
Username: username,
Email: email,
Password: p.passwordHash,

View File

@@ -2,15 +2,15 @@ package commands_test
import (
"errors"
"fmt"
"sync"
"testing"
"golang.org/x/crypto/bcrypt"
"goyco/cmd/goyco/commands"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/testutils"
"golang.org/x/crypto/bcrypt"
)
func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
@@ -25,7 +25,7 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
wantErr bool
}{
{
name: "creates users with deterministic fields",
name: "creates users with required fields",
count: successCount,
repoFactory: func() repositories.UserRepository {
base := testutils.NewMockUserRepository()
@@ -37,14 +37,24 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
if len(got) != successCount {
t.Fatalf("expected %d users, got %d", successCount, len(got))
}
usernames := make(map[string]bool)
for i, user := range got {
expectedUsername := fmt.Sprintf("user_%d", i+1)
expectedEmail := fmt.Sprintf("user_%d@goyco.local", i+1)
if user.Username != expectedUsername {
t.Errorf("user %d username mismatch: got %q want %q", i, user.Username, expectedUsername)
if user.Username == "" {
t.Errorf("user %d expected non-empty username", i)
}
if user.Email != expectedEmail {
t.Errorf("user %d email mismatch: got %q want %q", i, user.Email, expectedEmail)
if len(user.Username) < 6 || user.Username[:5] != "user_" {
t.Errorf("user %d username should start with 'user_', got %q", i, user.Username)
}
if usernames[user.Username] {
t.Errorf("user %d duplicate username: %q", i, user.Username)
}
usernames[user.Username] = true
if user.Email == "" {
t.Errorf("user %d expected non-empty email", i)
}
if len(user.Email) < 20 || user.Email[:5] != "user_" || user.Email[len(user.Email)-12:] != "@goyco.local" {
t.Errorf("user %d email should match pattern 'user_*@goyco.local', got %q", i, user.Email)
}
if !user.EmailVerified {
t.Errorf("user %d expected EmailVerified to be true", i)
@@ -83,6 +93,11 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
t.Parallel()
repo := tt.repoFactory()
p := commands.NewParallelProcessor()
passwordHash, err := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("failed to generate password hash: %v", err)
}
p.SetPasswordHash(string(passwordHash))
got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress)
if gotErr != nil {
if !tt.wantErr {

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"math/rand"
"os"
"strings"
"sync"
"time"
@@ -36,17 +35,6 @@ func initSeedRand() {
})
}
func generateRandomIdentifier() string {
initSeedRand()
const length = 12
const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
identifier := make([]byte, length)
for i := range identifier {
identifier[i] = chars[seedRandSource.Intn(len(chars))]
}
return string(identifier)
}
func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printSeedUsage)
if err := parseCommand(fs, args, name); err != nil {
@@ -57,15 +45,10 @@ func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
}
return withDatabase(cfg, func(db *gorm.DB) error {
return db.Transaction(func(tx *gorm.DB) error {
userRepo := repositories.NewUserRepository(db).WithTx(tx)
postRepo := repositories.NewPostRepository(db).WithTx(tx)
voteRepo := repositories.NewVoteRepository(db).WithTx(tx)
if err := runSeedCommand(userRepo, postRepo, voteRepo, fs.Args()); err != nil {
return err
}
return nil
})
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
return runSeedCommand(userRepo, postRepo, voteRepo, fs.Args())
})
}
@@ -219,7 +202,7 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
progress.Complete()
}
if err := validateSeedConsistency(postRepo, voteRepo, allUsers, posts); err != nil {
if err := validateSeedConsistency(voteRepo, allUsers, posts); err != nil {
return fmt.Errorf("seed consistency validation failed: %w", err)
}
@@ -242,44 +225,17 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
return nil
}
func findExistingSeedUser(userRepo repositories.UserRepository) (*database.User, error) {
users, err := userRepo.GetAll(100, 0)
if err != nil {
return nil, err
}
for _, user := range users {
if len(user.Username) >= 11 && user.Username[:11] == "seed_admin_" {
if len(user.Email) >= 13 && strings.HasSuffix(user.Email, "@goyco.local") {
emailPrefix := user.Email[:len(user.Email)-13]
if len(emailPrefix) >= 11 && emailPrefix[:11] == "seed_admin_" {
return &user, nil
}
}
}
}
return nil, fmt.Errorf("no existing seed user found")
}
const (
seedUsername = "seed_admin"
seedEmail = "seed_admin@goyco.local"
)
func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) {
existingUser, err := findExistingSeedUser(userRepo)
if err == nil && existingUser != nil {
return existingUser, nil
}
randomID := generateRandomIdentifier()
seedUsername := fmt.Sprintf("seed_admin_%s", randomID)
seedEmail := fmt.Sprintf("seed_admin_%s@goyco.local", randomID)
const maxRetries = 10
for range maxRetries {
user, err := userRepo.GetByEmail(seedEmail)
if err == nil {
if user, err := userRepo.GetByUsername(seedUsername); err == nil {
return user, nil
}
user = &database.User{
user := &database.User{
Username: seedUsername,
Email: seedEmail,
Password: passwordHash,
@@ -287,66 +243,75 @@ func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (
}
if err := userRepo.Create(user); err != nil {
randomID = generateRandomIdentifier()
seedUsername = fmt.Sprintf("seed_admin_%s", randomID)
seedEmail = fmt.Sprintf("seed_admin_%s@goyco.local", randomID)
continue
return nil, fmt.Errorf("failed to create seed user: %w", err)
}
return user, nil
}
return nil, fmt.Errorf("failed to create seed user after %d attempts", maxRetries)
}
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
votes, err := voteRepo.GetByPostID(postID)
if err != nil {
return 0, 0, err
return voteRepo.GetVoteCountsByPostID(postID)
}
upVotes := 0
downVotes := 0
for _, vote := range votes {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
return upVotes, downVotes, nil
}
func validateSeedConsistency(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
userIDs := make(map[uint]bool)
func validateSeedConsistency(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
userIDSet := make(map[uint]struct{}, len(users))
for _, user := range users {
userIDs[user.ID] = true
userIDSet[user.ID] = struct{}{}
}
postIDSet := make(map[uint]struct{}, len(posts))
for _, post := range posts {
postIDSet[post.ID] = struct{}{}
}
for _, post := range posts {
if post.AuthorID == nil {
return fmt.Errorf("post %d has no author", post.ID)
}
if !userIDs[*post.AuthorID] {
return fmt.Errorf("post %d has invalid author ID %d", post.ID, *post.AuthorID)
if err := validatePost(post, userIDSet); err != nil {
return err
}
votes, err := voteRepo.GetByPostID(post.ID)
if err != nil {
return fmt.Errorf("get votes for post %d: %w", post.ID, err)
return fmt.Errorf("failed to retrieve votes for post %d: %w", post.ID, err)
}
for _, vote := range votes {
if vote.UserID != nil && !userIDs[*vote.UserID] {
return fmt.Errorf("vote %d has invalid user ID %d", vote.ID, *vote.UserID)
}
if vote.PostID != post.ID {
return fmt.Errorf("vote %d has invalid post ID %d (expected %d)", vote.ID, vote.PostID, post.ID)
}
if err := validateVotesForPost(post.ID, votes, userIDSet, postIDSet); err != nil {
return err
}
}
return nil
}
func validatePost(post database.Post, userIDSet map[uint]struct{}) error {
if post.AuthorID == nil {
return fmt.Errorf("post %d has no author ID", post.ID)
}
if _, exists := userIDSet[*post.AuthorID]; !exists {
return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID)
}
return nil
}
func validateVotesForPost(postID uint, votes []database.Vote, userIDSet map[uint]struct{}, postIDSet map[uint]struct{}) error {
for _, vote := range votes {
if vote.PostID != postID {
return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, postID)
}
if _, exists := postIDSet[vote.PostID]; !exists {
return fmt.Errorf("vote %d references non-existent post ID %d", vote.ID, vote.PostID)
}
if vote.UserID != nil {
if _, exists := userIDSet[*vote.UserID]; !exists {
return fmt.Errorf("vote %d references non-existent user ID %d", vote.ID, *vote.UserID)
}
}
if vote.Type != database.VoteUp && vote.Type != database.VoteDown {
return fmt.Errorf("vote %d has invalid type %q", vote.ID, vote.Type)
}
}

View File

@@ -47,7 +47,7 @@ func TestSeedCommand(t *testing.T) {
var seedUser *database.User
regularUserCount := 0
for i := range users {
if strings.HasPrefix(users[i].Username, "seed_admin_") {
if users[i].Username == "seed_admin" {
seedUserCount++
seedUser = &users[i]
} else if strings.HasPrefix(users[i].Username, "user_") {
@@ -63,12 +63,12 @@ func TestSeedCommand(t *testing.T) {
t.Fatal("Expected seed user to be created")
}
if !strings.HasPrefix(seedUser.Username, "seed_admin_") {
t.Errorf("Expected username to start with 'seed_admin_', got '%s'", seedUser.Username)
if seedUser.Username != "seed_admin" {
t.Errorf("Expected username to be 'seed_admin', got '%s'", seedUser.Username)
}
if !strings.HasPrefix(seedUser.Email, "seed_admin_") || !strings.HasSuffix(seedUser.Email, "@goyco.local") {
t.Errorf("Expected email to start with 'seed_admin_' and end with '@goyco.local', got '%s'", seedUser.Email)
if seedUser.Email != "seed_admin@goyco.local" {
t.Errorf("Expected email to be 'seed_admin@goyco.local', got '%s'", seedUser.Email)
}
if !seedUser.EmailVerified {
@@ -302,13 +302,13 @@ func TestSeedCommandIdempotency(t *testing.T) {
seedUserCount := 0
for _, user := range users {
if strings.HasPrefix(user.Username, "seed_admin_") {
if user.Username == "seed_admin" {
seedUserCount++
}
}
if seedUserCount < 1 {
t.Errorf("Expected at least 1 seed user, got %d", seedUserCount)
if seedUserCount != 1 {
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
}
})
@@ -387,7 +387,7 @@ func TestSeedCommandIdempotency(t *testing.T) {
func findSeedUser(users []database.User) *database.User {
for i := range users {
if strings.HasPrefix(users[i].Username, "seed_admin_") {
if users[i].Username == "seed_admin" {
return &users[i]
}
}
@@ -476,3 +476,58 @@ func TestSeedCommandTransactionRollback(t *testing.T) {
}
})
}
func TestEnsureSeedUser(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
if err := db.AutoMigrate(&database.User{}); err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
passwordHash := "test_password_hash"
firstUser, err := ensureSeedUser(userRepo, passwordHash)
if err != nil {
t.Fatalf("Failed to create seed user: %v", err)
}
if firstUser.Username != "seed_admin" || firstUser.Email != "seed_admin@goyco.local" || firstUser.Password != passwordHash || !firstUser.EmailVerified {
t.Errorf("Invalid seed user: username=%s, email=%s, password matches=%v, emailVerified=%v",
firstUser.Username, firstUser.Email, firstUser.Password == passwordHash, firstUser.EmailVerified)
}
secondUser, err := ensureSeedUser(userRepo, "different_password_hash")
if err != nil {
t.Fatalf("Failed to reuse seed user: %v", err)
}
if firstUser.ID != secondUser.ID {
t.Errorf("Expected same user to be reused (ID %d), got different user (ID %d)", firstUser.ID, secondUser.ID)
}
for i := 0; i < 3; i++ {
if _, err := ensureSeedUser(userRepo, passwordHash); err != nil {
t.Fatalf("Call %d failed: %v", i+1, err)
}
}
users, err := userRepo.GetAll(100, 0)
if err != nil {
t.Fatalf("Failed to get users: %v", err)
}
seedUserCount := 0
for _, user := range users {
if user.Username == "seed_admin" {
seedUserCount++
}
}
if seedUserCount != 1 {
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
}
}

View File

@@ -96,21 +96,21 @@ func TestServerConfigurationFromConfig(t *testing.T) {
testServer := httptest.NewServer(srv.Handler)
defer testServer.Close()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
resp, err := http.DefaultClient.Do(req)
response, err := http.DefaultClient.Do(request)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", response.StatusCode)
}
}
@@ -208,27 +208,27 @@ func TestTLSWiringFromConfig(t *testing.T) {
},
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
resp, err := client.Do(req)
response, err := client.Do(request)
if err != nil {
t.Fatalf("Failed to make TLS request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 over TLS, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 over TLS, got %d", response.StatusCode)
}
if resp.TLS == nil {
if response.TLS == nil {
t.Error("Expected TLS connection info to be present in response")
} else if resp.TLS.Version < tls.VersionTLS12 {
t.Errorf("Expected TLS version 1.2 or higher, got %x", resp.TLS.Version)
} else if response.TLS.Version < tls.VersionTLS12 {
t.Errorf("Expected TLS version 1.2 or higher, got %x", response.TLS.Version)
}
}
}
@@ -368,38 +368,38 @@ func TestServerInitializationFlow(t *testing.T) {
testServer := httptest.NewServer(srv.Handler)
defer testServer.Close()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
resp, err := http.DefaultClient.Do(req)
response, err := http.DefaultClient.Do(request)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", response.StatusCode)
}
req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/api", nil)
request, err = http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/api", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
resp, err = http.DefaultClient.Do(req)
response, err = http.DefaultClient.Do(request)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for API endpoint, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for API endpoint, got %d", response.StatusCode)
}
}

View File

@@ -111,7 +111,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ConfirmAccountDeletionRequest"
"$ref": "#/definitions/dto.ConfirmAccountDeletionRequest"
}
}
],
@@ -212,7 +212,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateEmailRequest"
"$ref": "#/definitions/dto.UpdateEmailRequest"
}
}
],
@@ -276,7 +276,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ForgotPasswordRequest"
"$ref": "#/definitions/dto.ForgotPasswordRequest"
}
}
],
@@ -316,7 +316,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.LoginRequest"
"$ref": "#/definitions/dto.LoginRequest"
}
}
],
@@ -453,7 +453,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePasswordRequest"
"$ref": "#/definitions/dto.UpdatePasswordRequest"
}
}
],
@@ -505,7 +505,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RefreshTokenRequest"
"$ref": "#/definitions/dto.RefreshTokenRequest"
}
}
],
@@ -563,7 +563,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -615,7 +615,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResendVerificationRequest"
"$ref": "#/definitions/dto.ResendVerificationRequest"
}
}
],
@@ -685,7 +685,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResetPasswordRequest"
"$ref": "#/definitions/dto.ResetPasswordRequest"
}
}
],
@@ -736,7 +736,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RevokeTokenRequest"
"$ref": "#/definitions/dto.RevokeTokenRequest"
}
}
],
@@ -833,7 +833,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateUsernameRequest"
"$ref": "#/definitions/dto.UpdateUsernameRequest"
}
}
],
@@ -945,7 +945,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.CreatePostRequest"
"$ref": "#/definitions/dto.CreatePostRequest"
}
}
],
@@ -1176,7 +1176,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePostRequest"
"$ref": "#/definitions/dto.UpdatePostRequest"
}
}
],
@@ -1370,7 +1370,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.VoteRequest"
"$ref": "#/definitions/dto.CastVoteRequest"
}
}
],
@@ -1601,7 +1601,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -1817,6 +1817,223 @@ const docTemplate = `{
}
},
"definitions": {
"dto.CastVoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"dto.ConfirmAccountDeletionRequest": {
"type": "object",
"required": [
"token"
],
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"dto.CreatePostRequest": {
"type": "object",
"required": [
"url"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
},
"url": {
"type": "string",
"maxLength": 2048
}
}
},
"dto.ForgotPasswordRequest": {
"type": "object",
"required": [
"username_or_email"
],
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"dto.LoginRequest": {
"type": "object",
"required": [
"password",
"username"
],
"properties": {
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.RegisterRequest": {
"type": "object",
"required": [
"email",
"password",
"username"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
},
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.ResendVerificationRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.ResetPasswordRequest": {
"type": "object",
"required": [
"new_password",
"token"
],
"properties": {
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"token": {
"type": "string"
}
}
},
"dto.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.UpdateEmailRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.UpdatePasswordRequest": {
"type": "object",
"required": [
"current_password",
"new_password"
],
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
}
}
},
"dto.UpdatePostRequest": {
"type": "object",
"required": [
"title"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
}
}
},
"dto.UpdateUsernameRequest": {
"type": "object",
"required": [
"username"
],
"properties": {
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"handlers.APIInfo": {
"type": "object",
"properties": {
@@ -1919,50 +2136,6 @@ const docTemplate = `{
}
}
},
"handlers.ConfirmAccountDeletionRequest": {
"type": "object",
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"handlers.CreatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
},
"url": {
"type": "string"
}
}
},
"handlers.ForgotPasswordRequest": {
"type": "object",
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"handlers.LoginRequest": {
"type": "object",
"properties": {
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.PostResponse": {
"type": "object",
"properties": {
@@ -1978,101 +2151,6 @@ const docTemplate = `{
}
}
},
"handlers.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.RegisterRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
},
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.ResendVerificationRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.ResetPasswordRequest": {
"type": "object",
"properties": {
"new_password": {
"type": "string"
},
"token": {
"type": "string"
}
}
},
"handlers.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.UpdateEmailRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.UpdatePasswordRequest": {
"type": "object",
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string"
}
}
},
"handlers.UpdatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
}
}
},
"handlers.UpdateUsernameRequest": {
"type": "object",
"properties": {
"username": {
"type": "string"
}
}
},
"handlers.UserResponse": {
"type": "object",
"properties": {
@@ -2088,21 +2166,6 @@ const docTemplate = `{
}
}
},
"handlers.VoteRequest": {
"description": "Vote request with type field. All votes are handled the same way.",
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
],
"example": "up"
}
}
},
"handlers.VoteResponse": {
"type": "object",
"properties": {

View File

@@ -108,7 +108,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ConfirmAccountDeletionRequest"
"$ref": "#/definitions/dto.ConfirmAccountDeletionRequest"
}
}
],
@@ -209,7 +209,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateEmailRequest"
"$ref": "#/definitions/dto.UpdateEmailRequest"
}
}
],
@@ -273,7 +273,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ForgotPasswordRequest"
"$ref": "#/definitions/dto.ForgotPasswordRequest"
}
}
],
@@ -313,7 +313,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.LoginRequest"
"$ref": "#/definitions/dto.LoginRequest"
}
}
],
@@ -450,7 +450,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePasswordRequest"
"$ref": "#/definitions/dto.UpdatePasswordRequest"
}
}
],
@@ -502,7 +502,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RefreshTokenRequest"
"$ref": "#/definitions/dto.RefreshTokenRequest"
}
}
],
@@ -560,7 +560,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -612,7 +612,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResendVerificationRequest"
"$ref": "#/definitions/dto.ResendVerificationRequest"
}
}
],
@@ -682,7 +682,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResetPasswordRequest"
"$ref": "#/definitions/dto.ResetPasswordRequest"
}
}
],
@@ -733,7 +733,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RevokeTokenRequest"
"$ref": "#/definitions/dto.RevokeTokenRequest"
}
}
],
@@ -830,7 +830,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateUsernameRequest"
"$ref": "#/definitions/dto.UpdateUsernameRequest"
}
}
],
@@ -942,7 +942,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.CreatePostRequest"
"$ref": "#/definitions/dto.CreatePostRequest"
}
}
],
@@ -1173,7 +1173,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePostRequest"
"$ref": "#/definitions/dto.UpdatePostRequest"
}
}
],
@@ -1367,7 +1367,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.VoteRequest"
"$ref": "#/definitions/dto.CastVoteRequest"
}
}
],
@@ -1598,7 +1598,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -1814,6 +1814,223 @@
}
},
"definitions": {
"dto.CastVoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"dto.ConfirmAccountDeletionRequest": {
"type": "object",
"required": [
"token"
],
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"dto.CreatePostRequest": {
"type": "object",
"required": [
"url"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
},
"url": {
"type": "string",
"maxLength": 2048
}
}
},
"dto.ForgotPasswordRequest": {
"type": "object",
"required": [
"username_or_email"
],
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"dto.LoginRequest": {
"type": "object",
"required": [
"password",
"username"
],
"properties": {
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.RegisterRequest": {
"type": "object",
"required": [
"email",
"password",
"username"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
},
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.ResendVerificationRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.ResetPasswordRequest": {
"type": "object",
"required": [
"new_password",
"token"
],
"properties": {
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"token": {
"type": "string"
}
}
},
"dto.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.UpdateEmailRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.UpdatePasswordRequest": {
"type": "object",
"required": [
"current_password",
"new_password"
],
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
}
}
},
"dto.UpdatePostRequest": {
"type": "object",
"required": [
"title"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
}
}
},
"dto.UpdateUsernameRequest": {
"type": "object",
"required": [
"username"
],
"properties": {
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"handlers.APIInfo": {
"type": "object",
"properties": {
@@ -1916,50 +2133,6 @@
}
}
},
"handlers.ConfirmAccountDeletionRequest": {
"type": "object",
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"handlers.CreatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
},
"url": {
"type": "string"
}
}
},
"handlers.ForgotPasswordRequest": {
"type": "object",
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"handlers.LoginRequest": {
"type": "object",
"properties": {
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.PostResponse": {
"type": "object",
"properties": {
@@ -1975,101 +2148,6 @@
}
}
},
"handlers.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.RegisterRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
},
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.ResendVerificationRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.ResetPasswordRequest": {
"type": "object",
"properties": {
"new_password": {
"type": "string"
},
"token": {
"type": "string"
}
}
},
"handlers.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.UpdateEmailRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.UpdatePasswordRequest": {
"type": "object",
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string"
}
}
},
"handlers.UpdatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
}
}
},
"handlers.UpdateUsernameRequest": {
"type": "object",
"properties": {
"username": {
"type": "string"
}
}
},
"handlers.UserResponse": {
"type": "object",
"properties": {
@@ -2085,21 +2163,6 @@
}
}
},
"handlers.VoteRequest": {
"description": "Vote request with type field. All votes are handled the same way.",
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
],
"example": "up"
}
}
},
"handlers.VoteResponse": {
"type": "object",
"properties": {

View File

@@ -1,5 +1,156 @@
basePath: /api
definitions:
dto.CastVoteRequest:
properties:
type:
enum:
- up
- down
- none
type: string
required:
- type
type: object
dto.ConfirmAccountDeletionRequest:
properties:
delete_posts:
type: boolean
token:
type: string
required:
- token
type: object
dto.CreatePostRequest:
properties:
content:
maxLength: 10000
type: string
title:
maxLength: 200
minLength: 3
type: string
url:
maxLength: 2048
type: string
required:
- url
type: object
dto.ForgotPasswordRequest:
properties:
username_or_email:
type: string
required:
- username_or_email
type: object
dto.LoginRequest:
properties:
password:
maxLength: 128
minLength: 8
type: string
username:
maxLength: 50
minLength: 3
type: string
required:
- password
- username
type: object
dto.RefreshTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
dto.RegisterRequest:
properties:
email:
maxLength: 254
type: string
password:
maxLength: 128
minLength: 8
type: string
username:
maxLength: 50
minLength: 3
type: string
required:
- email
- password
- username
type: object
dto.ResendVerificationRequest:
properties:
email:
maxLength: 254
type: string
required:
- email
type: object
dto.ResetPasswordRequest:
properties:
new_password:
maxLength: 128
minLength: 8
type: string
token:
type: string
required:
- new_password
- token
type: object
dto.RevokeTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
dto.UpdateEmailRequest:
properties:
email:
maxLength: 254
type: string
required:
- email
type: object
dto.UpdatePasswordRequest:
properties:
current_password:
type: string
new_password:
maxLength: 128
minLength: 8
type: string
required:
- current_password
- new_password
type: object
dto.UpdatePostRequest:
properties:
content:
maxLength: 10000
type: string
title:
maxLength: 200
minLength: 3
type: string
required:
- title
type: object
dto.UpdateUsernameRequest:
properties:
username:
maxLength: 50
minLength: 3
type: string
required:
- username
type: object
handlers.APIInfo:
properties:
data: {}
@@ -70,34 +221,6 @@ definitions:
success:
type: boolean
type: object
handlers.ConfirmAccountDeletionRequest:
properties:
delete_posts:
type: boolean
token:
type: string
type: object
handlers.CreatePostRequest:
properties:
content:
type: string
title:
type: string
url:
type: string
type: object
handlers.ForgotPasswordRequest:
properties:
username_or_email:
type: string
type: object
handlers.LoginRequest:
properties:
password:
type: string
username:
type: string
type: object
handlers.PostResponse:
properties:
data: {}
@@ -108,67 +231,6 @@ definitions:
success:
type: boolean
type: object
handlers.RefreshTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
handlers.RegisterRequest:
properties:
email:
type: string
password:
type: string
username:
type: string
type: object
handlers.ResendVerificationRequest:
properties:
email:
type: string
type: object
handlers.ResetPasswordRequest:
properties:
new_password:
type: string
token:
type: string
type: object
handlers.RevokeTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
handlers.UpdateEmailRequest:
properties:
email:
type: string
type: object
handlers.UpdatePasswordRequest:
properties:
current_password:
type: string
new_password:
type: string
type: object
handlers.UpdatePostRequest:
properties:
content:
type: string
title:
type: string
type: object
handlers.UpdateUsernameRequest:
properties:
username:
type: string
type: object
handlers.UserResponse:
properties:
data: {}
@@ -179,17 +241,6 @@ definitions:
success:
type: boolean
type: object
handlers.VoteRequest:
description: Vote request with type field. All votes are handled the same way.
properties:
type:
enum:
- up
- down
- none
example: up
type: string
type: object
handlers.VoteResponse:
properties:
data: {}
@@ -268,7 +319,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ConfirmAccountDeletionRequest'
$ref: '#/definitions/dto.ConfirmAccountDeletionRequest'
produces:
- application/json
responses:
@@ -331,7 +382,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdateEmailRequest'
$ref: '#/definitions/dto.UpdateEmailRequest'
produces:
- application/json
responses:
@@ -375,7 +426,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ForgotPasswordRequest'
$ref: '#/definitions/dto.ForgotPasswordRequest'
produces:
- application/json
responses:
@@ -401,7 +452,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.LoginRequest'
$ref: '#/definitions/dto.LoginRequest'
produces:
- application/json
responses:
@@ -485,7 +536,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdatePasswordRequest'
$ref: '#/definitions/dto.UpdatePasswordRequest'
produces:
- application/json
responses:
@@ -523,7 +574,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RefreshTokenRequest'
$ref: '#/definitions/dto.RefreshTokenRequest'
produces:
- application/json
responses:
@@ -561,7 +612,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RegisterRequest'
$ref: '#/definitions/dto.RegisterRequest'
produces:
- application/json
responses:
@@ -595,7 +646,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ResendVerificationRequest'
$ref: '#/definitions/dto.ResendVerificationRequest'
produces:
- application/json
responses:
@@ -641,7 +692,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ResetPasswordRequest'
$ref: '#/definitions/dto.ResetPasswordRequest'
produces:
- application/json
responses:
@@ -672,7 +723,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RevokeTokenRequest'
$ref: '#/definitions/dto.RevokeTokenRequest'
produces:
- application/json
responses:
@@ -735,7 +786,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdateUsernameRequest'
$ref: '#/definitions/dto.UpdateUsernameRequest'
produces:
- application/json
responses:
@@ -809,7 +860,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.CreatePostRequest'
$ref: '#/definitions/dto.CreatePostRequest'
produces:
- application/json
responses:
@@ -933,7 +984,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdatePostRequest'
$ref: '#/definitions/dto.UpdatePostRequest'
produces:
- application/json
responses:
@@ -1070,7 +1121,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.VoteRequest'
$ref: '#/definitions/dto.CastVoteRequest'
produces:
- application/json
responses:
@@ -1260,7 +1311,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RegisterRequest'
$ref: '#/definitions/dto.RegisterRequest'
produces:
- application/json
responses:

View File

@@ -0,0 +1,51 @@
package dto
type LoginRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
Password string `json:"password" validate:"required,min=8,max=128"`
}
type RegisterRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
Email string `json:"email" validate:"required,email,max=254"`
Password string `json:"password" validate:"required,min=8,max=128"`
}
type ResendVerificationRequest struct {
Email string `json:"email" validate:"required,email,max=254"`
}
type ForgotPasswordRequest struct {
UsernameOrEmail string `json:"username_or_email" validate:"required"`
}
type ResetPasswordRequest struct {
Token string `json:"token" validate:"required"`
NewPassword string `json:"new_password" validate:"required,min=8,max=128"`
}
type UpdateEmailRequest struct {
Email string `json:"email" validate:"required,email,max=254"`
}
type UpdateUsernameRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
}
type UpdatePasswordRequest struct {
CurrentPassword string `json:"current_password" validate:"required"`
NewPassword string `json:"new_password" validate:"required,min=8,max=128"`
}
type ConfirmAccountDeletionRequest struct {
Token string `json:"token" validate:"required"`
DeletePosts bool `json:"delete_posts"`
}
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." validate:"required"`
}
type RevokeTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." validate:"required"`
}

View File

@@ -0,0 +1,12 @@
package dto
type CreatePostRequest struct {
Title string `json:"title" validate:"omitempty,min=3,max=200"`
URL string `json:"url" validate:"required,url,max=2048"`
Content string `json:"content" validate:"omitempty,max=10000"`
}
type UpdatePostRequest struct {
Title string `json:"title" validate:"required,min=3,max=200"`
Content string `json:"content" validate:"omitempty,max=10000"`
}

View File

@@ -6,6 +6,10 @@ import (
"goyco/internal/database"
)
type CastVoteRequest struct {
Type string `json:"type" validate:"required,oneof=up down none"`
}
type VoteDTO struct {
ID uint `json:"id"`
UserID *uint `json:"user_id,omitempty"`

View File

@@ -112,37 +112,37 @@ func newInMemoryRoundTripper(handler http.Handler) http.RoundTripper {
return &inMemoryRoundTripper{handler: handler}
}
func (rt *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func (rt *inMemoryRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
if rt == nil || rt.handler == nil {
return nil, fmt.Errorf("in-memory round tripper not initialized")
}
var bodyBytes []byte
if req.Body != nil && req.Body != http.NoBody {
defer req.Body.Close()
if request.Body != nil && request.Body != http.NoBody {
defer request.Body.Close()
var err error
bodyBytes, err = io.ReadAll(req.Body)
bodyBytes, err = io.ReadAll(request.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}
}
if len(bodyBytes) > 0 {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
} else {
req.Body = http.NoBody
request.Body = http.NoBody
}
clonedReq := req.Clone(req.Context())
clonedRequest := request.Clone(request.Context())
if len(bodyBytes) > 0 {
clonedReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
clonedRequest.Body = io.NopCloser(bytes.NewReader(bodyBytes))
} else {
clonedReq.Body = http.NoBody
clonedRequest.Body = http.NoBody
}
clonedReq.RequestURI = clonedReq.URL.RequestURI()
clonedRequest.RequestURI = clonedRequest.URL.RequestURI()
recorder := httptest.NewRecorder()
rt.handler.ServeHTTP(recorder, clonedReq)
rt.handler.ServeHTTP(recorder, clonedRequest)
resp := recorder.Result()
return resp, nil
}
@@ -250,7 +250,7 @@ func tokenHash(token string) string {
func retryOnRateLimit(t *testing.T, maxRetries int, operation func() int) int {
t.Helper()
for attempt := 0; attempt < maxRetries; attempt++ {
for attempt := range maxRetries {
statusCode := operation()
if statusCode != http.StatusTooManyRequests {
return statusCode

View File

@@ -15,22 +15,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
ctx := setupTestContext(t)
t.Run("compression_enabled_with_accept_encoding", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Accept-Encoding", "gzip")
testutils.WithStandardHeaders(req)
request.Header.Set("Accept-Encoding", "gzip")
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
contentEncoding := resp.Header.Get("Content-Encoding")
contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" {
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(response.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
@@ -57,19 +57,19 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
})
t.Run("no_compression_without_accept_encoding", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req)
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
contentEncoding := resp.Header.Get("Content-Encoding")
contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" {
t.Error("Expected no compression without Accept-Encoding header")
}
@@ -85,22 +85,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
gz.Write([]byte(postData))
gz.Close()
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf)
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Encoding", "gzip")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Content-Encoding", "gzip")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
switch resp.StatusCode {
switch response.StatusCode {
case http.StatusBadRequest:
t.Log("Decompression middleware rejected invalid gzip")
case http.StatusCreated, http.StatusOK:
@@ -113,37 +113,37 @@ func TestE2E_CacheMiddleware(t *testing.T) {
ctx := setupTestContext(t)
t.Run("cache_miss_then_hit", func(t *testing.T) {
req1, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req1)
testutils.WithStandardHeaders(firstRequest)
resp1, err := ctx.client.Do(req1)
firstResponse, err := ctx.client.Do(firstRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
resp1.Body.Close()
firstResponse.Body.Close()
cacheStatus1 := resp1.Header.Get("X-Cache")
if cacheStatus1 == "HIT" {
firstCacheStatus := firstResponse.Header.Get("X-Cache")
if firstCacheStatus == "HIT" {
t.Log("First request was cached (unexpected but acceptable)")
}
req2, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
secondRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req2)
testutils.WithStandardHeaders(secondRequest)
resp2, err := ctx.client.Do(req2)
secondResponse, err := ctx.client.Do(secondRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp2.Body.Close()
defer secondResponse.Body.Close()
cacheStatus2 := resp2.Header.Get("X-Cache")
if cacheStatus2 == "HIT" {
secondCacheStatus := secondResponse.Header.Get("X-Cache")
if secondCacheStatus == "HIT" {
t.Log("Second request was served from cache")
}
})
@@ -152,48 +152,48 @@ func TestE2E_CacheMiddleware(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "cacheuser", "StrongPass123!")
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
req1, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req1)
req1.Header.Set("Authorization", "Bearer "+authClient.Token)
testutils.WithStandardHeaders(firstRequest)
firstRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp1, err := ctx.client.Do(req1)
firstResponse, err := ctx.client.Do(firstRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
resp1.Body.Close()
firstResponse.Body.Close()
postData := `{"title":"Cache Invalidation Test","url":"https://example.com/cache","content":"Test"}`
req2, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
secondRequest, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req2.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req2)
req2.Header.Set("Authorization", "Bearer "+authClient.Token)
secondRequest.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(secondRequest)
secondRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp2, err := ctx.client.Do(req2)
secondResponse, err := ctx.client.Do(secondRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
resp2.Body.Close()
secondResponse.Body.Close()
req3, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
thirdRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req3)
req3.Header.Set("Authorization", "Bearer "+authClient.Token)
testutils.WithStandardHeaders(thirdRequest)
thirdRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp3, err := ctx.client.Do(req3)
thirdResponse, err := ctx.client.Do(thirdRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp3.Body.Close()
defer thirdResponse.Body.Close()
cacheStatus := resp3.Header.Get("X-Cache")
cacheStatus := thirdResponse.Header.Get("X-Cache")
if cacheStatus == "HIT" {
t.Log("Cache was invalidated after POST")
}
@@ -204,23 +204,23 @@ func TestE2E_CSRFProtection(t *testing.T) {
ctx := setupTestContext(t)
t.Run("csrf_protection_for_non_api_routes", func(t *testing.T) {
req, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`))
request, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden {
if response.StatusCode == http.StatusForbidden {
t.Log("CSRF protection active for non-API routes")
} else {
t.Logf("CSRF check result: status %d", resp.StatusCode)
t.Logf("CSRF check result: status %d", response.StatusCode)
}
})
@@ -229,39 +229,39 @@ func TestE2E_CSRFProtection(t *testing.T) {
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
postData := `{"title":"CSRF Test","url":"https://example.com/csrf","content":"Test"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden {
if response.StatusCode == http.StatusForbidden {
t.Error("API routes should bypass CSRF protection")
}
})
t.Run("csrf_allows_get_requests", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil)
request, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req)
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden {
if response.StatusCode == http.StatusForbidden {
t.Error("GET requests should not require CSRF token")
}
})
@@ -276,21 +276,21 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
smallData := strings.Repeat("a", 100)
postData := `{"title":"` + smallData + `","url":"https://example.com","content":"test"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusRequestEntityTooLarge {
if response.StatusCode == http.StatusRequestEntityTooLarge {
t.Error("Small request should not exceed size limit")
}
})
@@ -301,24 +301,24 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
largeData := strings.Repeat("a", 2*1024*1024)
postData := `{"title":"test","url":"https://example.com","content":"` + largeData + `"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
return
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusRequestEntityTooLarge {
if response.StatusCode == http.StatusRequestEntityTooLarge {
t.Log("Request size limit enforced correctly")
} else {
t.Logf("Request size limit check result: status %d", resp.StatusCode)
t.Logf("Request size limit check result: status %d", response.StatusCode)
}
})
}

View File

@@ -64,62 +64,6 @@ type AuthUserSummary struct {
Locked bool `json:"locked" example:"false"`
}
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type RegisterRequest struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
type CreatePostRequest struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
type ResendVerificationRequest struct {
Email string `json:"email"`
}
type ForgotPasswordRequest struct {
UsernameOrEmail string `json:"username_or_email"`
}
type ResetPasswordRequest struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
type UpdateEmailRequest struct {
Email string `json:"email"`
}
type UpdateUsernameRequest struct {
Username string `json:"username"`
}
type UpdatePasswordRequest struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
type ConfirmAccountDeletionRequest struct {
Token string `json:"token"`
DeletePosts bool `json:"delete_posts"`
}
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
}
type RevokeTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
}
func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.UserRepository) *AuthHandler {
return &AuthHandler{
authService: authService,
@@ -132,7 +76,7 @@ func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.User
// @Tags auth
// @Accept json
// @Produce json
// @Param request body LoginRequest true "Login credentials"
// @Param request body dto.LoginRequest true "Login credentials"
// @Success 200 {object} AuthTokensResponse "Authentication successful"
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
// @Failure 401 {object} AuthResponse "Invalid credentials"
@@ -140,23 +84,15 @@ func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.User
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/login [post]
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.LoginRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
username := security.SanitizeUsername(req.Username)
password := strings.TrimSpace(req.Password)
if username == "" || password == "" {
SendErrorResponse(w, "Username and password are required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(password); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
@@ -175,20 +111,16 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body RegisterRequest true "Registration data"
// @Param request body dto.RegisterRequest true "Registration data"
// @Success 201 {object} AuthResponse "Registration successful"
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
// @Failure 409 {object} AuthResponse "Username or email already exists"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/register [post]
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RegisterRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -196,11 +128,6 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
email := strings.TrimSpace(req.Email)
password := strings.TrimSpace(req.Password)
if username == "" || email == "" || password == "" {
SendErrorResponse(w, "Username, email, and password are required", http.StatusBadRequest)
return
}
username = security.SanitizeUsername(username)
if err := validation.ValidateUsername(username); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
@@ -280,7 +207,7 @@ func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ResendVerificationRequest true "Email address"
// @Param request body dto.ResendVerificationRequest true "Email address"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 404 {object} AuthResponse
@@ -290,15 +217,14 @@ func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse
// @Router /api/auth/resend-verification [post]
func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ResendVerificationRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
email := strings.TrimSpace(req.Email)
if email == "" {
SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
return
@@ -359,20 +285,19 @@ func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ForgotPasswordRequest true "Username or email"
// @Param request body dto.ForgotPasswordRequest true "Username or email"
// @Success 200 {object} AuthResponse "Password reset email sent if account exists"
// @Failure 400 {object} AuthResponse "Invalid request data"
// @Router /api/auth/forgot-password [post]
func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
var req struct {
UsernameOrEmail string `json:"username_or_email"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ForgotPasswordRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail)
if usernameOrEmail == "" {
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
return
@@ -389,18 +314,15 @@ func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Reques
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ResetPasswordRequest true "Password reset data"
// @Param request body dto.ResetPasswordRequest true "Password reset data"
// @Success 200 {object} AuthResponse "Password reset successfully"
// @Failure 400 {object} AuthResponse "Invalid or expired token, or validation failed"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/reset-password [post]
func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
var req struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ResetPasswordRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -408,17 +330,12 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
newPassword := strings.TrimSpace(req.NewPassword)
if token == "" {
SendErrorResponse(w, "Reset token is required", http.StatusBadRequest)
SendErrorResponse(w, "Token is required", http.StatusBadRequest)
return
}
if newPassword == "" {
SendErrorResponse(w, "New password is required", http.StatusBadRequest)
return
}
if len(newPassword) < 8 {
SendErrorResponse(w, "Password must be at least 8 characters long", http.StatusBadRequest)
if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
@@ -443,7 +360,7 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdateEmailRequest true "New email address"
// @Param request body dto.UpdateEmailRequest true "New email address"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
@@ -457,11 +374,9 @@ func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
Email string `json:"email"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdateEmailRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -498,7 +413,7 @@ func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdateUsernameRequest true "New username"
// @Param request body dto.UpdateUsernameRequest true "New username"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
@@ -511,11 +426,9 @@ func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
Username string `json:"username"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdateUsernameRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -548,7 +461,7 @@ func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdatePasswordRequest true "Password update data"
// @Param request body dto.UpdatePasswordRequest true "Password update data"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
@@ -560,12 +473,9 @@ func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdatePasswordRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -633,23 +543,21 @@ func (h *AuthHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ConfirmAccountDeletionRequest true "Account deletion data"
// @Param request body dto.ConfirmAccountDeletionRequest true "Account deletion data"
// @Success 200 {object} AuthResponse "Account deleted successfully"
// @Failure 400 {object} AuthResponse "Invalid or expired token"
// @Failure 503 {object} AuthResponse "Email delivery unavailable"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/account/confirm [post]
func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
var req struct {
Token string `json:"token"`
DeletePosts bool `json:"delete_posts"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ConfirmAccountDeletionRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
token := strings.TrimSpace(req.Token)
if token == "" {
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
return
@@ -694,7 +602,7 @@ func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body RefreshTokenRequest true "Refresh token data"
// @Param request body dto.RefreshTokenRequest true "Refresh token data"
// @Success 200 {object} AuthTokensResponse "Token refreshed successfully"
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
// @Failure 401 {object} AuthResponse "Invalid or expired refresh token"
@@ -702,13 +610,13 @@ func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/refresh [post]
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
var req RefreshTokenRequest
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RefreshTokenRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.RefreshToken) == "" {
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
@@ -727,20 +635,20 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body RevokeTokenRequest true "Token revocation data"
// @Param request body dto.RevokeTokenRequest true "Token revocation data"
// @Success 200 {object} AuthResponse "Token revoked successfully"
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
// @Failure 401 {object} AuthResponse "Invalid or expired access token"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/revoke [post]
func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
var req RevokeTokenRequest
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RevokeTokenRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.RefreshToken) == "" {
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
@@ -782,28 +690,28 @@ func (h *AuthHandler) RevokeAllTokens(w http.ResponseWriter, r *http.Request) {
func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
if config.GeneralRateLimit != nil {
rateLimited := config.GeneralRateLimit(r)
rateLimited.Post("/auth/refresh", h.RefreshToken)
rateLimited.Post("/auth/refresh", WithValidation[dto.RefreshTokenRequest](config.ValidationMiddleware, h.RefreshToken))
rateLimited.Get("/auth/confirm", h.ConfirmEmail)
rateLimited.Post("/auth/resend-verification", h.ResendVerificationEmail)
rateLimited.Post("/auth/resend-verification", WithValidation[dto.ResendVerificationRequest](config.ValidationMiddleware, h.ResendVerificationEmail))
} else {
r.Post("/auth/refresh", h.RefreshToken)
r.Post("/auth/refresh", WithValidation[dto.RefreshTokenRequest](config.ValidationMiddleware, h.RefreshToken))
r.Get("/auth/confirm", h.ConfirmEmail)
r.Post("/auth/resend-verification", h.ResendVerificationEmail)
r.Post("/auth/resend-verification", WithValidation[dto.ResendVerificationRequest](config.ValidationMiddleware, h.ResendVerificationEmail))
}
if config.AuthRateLimit != nil {
rateLimited := config.AuthRateLimit(r)
rateLimited.Post("/auth/register", h.Register)
rateLimited.Post("/auth/login", h.Login)
rateLimited.Post("/auth/forgot-password", h.RequestPasswordReset)
rateLimited.Post("/auth/reset-password", h.ResetPassword)
rateLimited.Post("/auth/account/confirm", h.ConfirmAccountDeletion)
rateLimited.Post("/auth/register", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.Register))
rateLimited.Post("/auth/login", WithValidation[dto.LoginRequest](config.ValidationMiddleware, h.Login))
rateLimited.Post("/auth/forgot-password", WithValidation[dto.ForgotPasswordRequest](config.ValidationMiddleware, h.RequestPasswordReset))
rateLimited.Post("/auth/reset-password", WithValidation[dto.ResetPasswordRequest](config.ValidationMiddleware, h.ResetPassword))
rateLimited.Post("/auth/account/confirm", WithValidation[dto.ConfirmAccountDeletionRequest](config.ValidationMiddleware, h.ConfirmAccountDeletion))
} else {
r.Post("/auth/register", h.Register)
r.Post("/auth/login", h.Login)
r.Post("/auth/forgot-password", h.RequestPasswordReset)
r.Post("/auth/reset-password", h.ResetPassword)
r.Post("/auth/account/confirm", h.ConfirmAccountDeletion)
r.Post("/auth/register", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.Register))
r.Post("/auth/login", WithValidation[dto.LoginRequest](config.ValidationMiddleware, h.Login))
r.Post("/auth/forgot-password", WithValidation[dto.ForgotPasswordRequest](config.ValidationMiddleware, h.RequestPasswordReset))
r.Post("/auth/reset-password", WithValidation[dto.ResetPasswordRequest](config.ValidationMiddleware, h.ResetPassword))
r.Post("/auth/account/confirm", WithValidation[dto.ConfirmAccountDeletionRequest](config.ValidationMiddleware, h.ConfirmAccountDeletion))
}
protected := r
@@ -816,10 +724,10 @@ func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
protected.Get("/auth/me", h.Me)
protected.Post("/auth/logout", h.Logout)
protected.Post("/auth/revoke", h.RevokeToken)
protected.Post("/auth/revoke", WithValidation[dto.RevokeTokenRequest](config.ValidationMiddleware, h.RevokeToken))
protected.Post("/auth/revoke-all", h.RevokeAllTokens)
protected.Put("/auth/email", h.UpdateEmail)
protected.Put("/auth/username", h.UpdateUsername)
protected.Put("/auth/password", h.UpdatePassword)
protected.Put("/auth/email", WithValidation[dto.UpdateEmailRequest](config.ValidationMiddleware, h.UpdateEmail))
protected.Put("/auth/username", WithValidation[dto.UpdateUsernameRequest](config.ValidationMiddleware, h.UpdateUsername))
protected.Put("/auth/password", WithValidation[dto.UpdatePasswordRequest](config.ValidationMiddleware, h.UpdatePassword))
protected.Delete("/auth/account", h.DeleteAccount)
}

View File

@@ -252,8 +252,8 @@ func TestAuthHandlerLoginSuccess(t *testing.T) {
}
handler := newAuthHandler(repo)
body := bytes.NewBufferString(`{"username":"user","password":"Password123!"}`)
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", body)
bodyStr := `{"username":"user","password":"Password123!"}`
request := createLoginRequest(bodyStr)
recorder := httptest.NewRecorder()
handler.Login(recorder, request)
@@ -274,17 +274,17 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler := newAuthHandler(&testutils.UserRepositoryStub{})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString("invalid"))
request := createLoginRequest("invalid")
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":" ","password":""}`))
request = createLoginRequest(`{"username":" ","password":""}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"WrongPass123!"}`))
request = createLoginRequest(`{"username":"user","password":"WrongPass123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
@@ -294,7 +294,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
}}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`))
request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden)
@@ -304,7 +304,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`))
request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
@@ -330,8 +330,7 @@ func TestAuthHandlerRegisterSuccess(t *testing.T) {
return nil
}})
body := bytes.NewBufferString(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`)
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", body)
request := createRegisterRequest(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.Register(recorder, request)
@@ -354,12 +353,12 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
handler := newAuthHandler(repo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString("invalid"))
request := createRegisterRequest("invalid")
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"","email":"","password":""}`))
request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -368,7 +367,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
}}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"new","email":"taken@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"new","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
@@ -382,7 +381,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"another","email":"taken@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"another","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
}
@@ -477,7 +476,7 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user@example.com"}`))
request := createForgotPasswordRequest(`{"username_or_email":"user@example.com"}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
@@ -495,19 +494,19 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}})
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user"}`))
request = createForgotPasswordRequest(`{"username_or_email":"user"}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":""}`))
request = createForgotPasswordRequest(`{"username_or_email":""}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`invalid json`))
request = createForgotPasswordRequest(`invalid json`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -518,25 +517,25 @@ func TestAuthHandlerResetPassword(t *testing.T) {
handler := newAuthHandler(repo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"new_password":"NewPassword123!"}`))
request := createResetPasswordRequest(`{"new_password":"NewPassword123!"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token"}`))
request = createResetPasswordRequest(`{"token":"valid_token"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token","new_password":"short"}`))
request = createResetPasswordRequest(`{"token":"valid_token","new_password":"short"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`invalid json`))
request = createResetPasswordRequest(`invalid json`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -602,7 +601,7 @@ func TestAuthHandlerResetPasswordServiceOutcomes(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"abc","new_password":"Password123!"}`))
request := createResetPasswordRequest(`{"token":"abc","new_password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.ResetPassword(recorder, request)
@@ -664,7 +663,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "empty email",
@@ -702,7 +701,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/email", bytes.NewBufferString(tt.requestBody))
request := createUpdateEmailRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -789,7 +788,7 @@ func TestAuthHandlerUpdateUsername(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/username", bytes.NewBufferString(tt.requestBody))
request := createUpdateUsernameRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -886,7 +885,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
tt.mockSetup(repo)
handler := newAuthHandler(repo)
request := httptest.NewRequest(http.MethodPut, "/api/auth/password", bytes.NewBufferString(tt.requestBody))
request := createUpdatePasswordRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -984,8 +983,7 @@ func TestAuthHandlerDeleteAccount(t *testing.T) {
func TestAuthHandlerResendVerificationEmail(t *testing.T) {
makeRequest := func(body string, setup func(*mockAuthService)) (*httptest.ResponseRecorder, AuthResponse) {
request := httptest.NewRequest(http.MethodPost, "/api/auth/resend-verification", bytes.NewBufferString(body))
request = request.WithContext(context.Background())
request := createResendVerificationRequest(body)
repo := &testutils.UserRepositoryStub{}
mockService := &mockAuthService{}
@@ -1014,7 +1012,7 @@ func TestAuthHandlerResendVerificationEmail(t *testing.T) {
name: "invalid json",
body: "not-json",
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "missing email",
@@ -1139,7 +1137,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
name: "invalid json",
body: "not-json",
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "missing token",
@@ -1209,7 +1207,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/account/confirm", bytes.NewBufferString(tt.body))
request := createConfirmAccountDeletionRequest(tt.body)
recorder := httptest.NewRecorder()
handler.ConfirmAccountDeletion(recorder, request)
@@ -1338,9 +1336,7 @@ func TestAuthHandler_ConcurrentAccess(t *testing.T) {
for i := 0; i < concurrency; i++ {
go func() {
body := bytes.NewBufferString(`{"username":"testuser","password":"Password123!"}`)
req := httptest.NewRequest("POST", "/api/auth/login", body)
req.Header.Set("Content-Type", "application/json")
req := createLoginRequest(`{"username":"testuser","password":"Password123!"}`)
w := httptest.NewRecorder()
handler.Login(w, req)
@@ -1370,8 +1366,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}, nil
}
body := bytes.NewBufferString(`{"refresh_token":"valid_refresh_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"valid_refresh_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1381,8 +1376,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
})
t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`invalid json`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1392,8 +1386,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
})
t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":""}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1407,8 +1400,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenExpired
}
body := bytes.NewBufferString(`{"refresh_token":"expired_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"expired_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1422,8 +1414,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenInvalid
}
body := bytes.NewBufferString(`{"refresh_token":"invalid_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"invalid_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1437,8 +1428,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrAccountLocked
}
body := bytes.NewBufferString(`{"refresh_token":"locked_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"locked_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1452,8 +1442,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, fmt.Errorf("internal error")
}
body := bytes.NewBufferString(`{"refresh_token":"error_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"error_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1473,8 +1462,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return nil
}
body := bytes.NewBufferString(`{"refresh_token":"token_to_revoke"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":"token_to_revoke"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1484,8 +1472,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
})
t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`invalid json`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1495,8 +1482,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
})
t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":""}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1510,8 +1496,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return fmt.Errorf("revoke failed")
}
body := bytes.NewBufferString(`{"refresh_token":"token"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":"token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/url"
"reflect"
"strconv"
"strings"
"time"
@@ -290,3 +291,24 @@ func HandleServiceError(w http.ResponseWriter, err error, defaultMsg string, def
SendErrorResponse(w, defaultMsg, defaultCode)
return false
}
func GetValidatedDTO[T any](r *http.Request) (*T, bool) {
dtoVal := middleware.GetValidatedDTOFromContext(r.Context())
if dtoVal == nil {
return nil, false
}
dto, ok := dtoVal.(*T)
return dto, ok
}
func WithValidation[T any](validationMiddleware func(http.Handler) http.Handler, handler http.HandlerFunc) http.HandlerFunc {
if validationMiddleware == nil {
return handler
}
var zero T
dtoType := reflect.TypeOf(zero)
return func(w http.ResponseWriter, r *http.Request) {
ctx := middleware.SetDTOTypeInContext(r.Context(), dtoType)
validationMiddleware(handler).ServeHTTP(w, r.WithContext(ctx))
}
}

View File

@@ -1,6 +1,7 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"errors"
@@ -11,6 +12,7 @@ import (
"testing"
"goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/middleware"
"goyco/internal/services"
"goyco/internal/testutils"
@@ -721,6 +723,74 @@ func TestDecodeJSONRequest(t *testing.T) {
}
}
func createRequestWithDTO[T any](method, url string, body []byte) *http.Request {
r := httptest.NewRequest(method, url, bytes.NewReader(body))
var dto T
if len(body) > 0 {
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&dto); err != nil {
return r
}
}
ctx := middleware.SetValidatedDTOInContext(r.Context(), &dto)
return r.WithContext(ctx)
}
func createLoginRequest(body string) *http.Request {
return createRequestWithDTO[dto.LoginRequest](http.MethodPost, "/api/auth/login", []byte(body))
}
func createRegisterRequest(body string) *http.Request {
return createRequestWithDTO[dto.RegisterRequest](http.MethodPost, "/api/auth/register", []byte(body))
}
func createResendVerificationRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResendVerificationRequest](http.MethodPost, "/api/auth/resend-verification", []byte(body))
}
func createForgotPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ForgotPasswordRequest](http.MethodPost, "/api/auth/forgot-password", []byte(body))
}
func createResetPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResetPasswordRequest](http.MethodPost, "/api/auth/reset-password", []byte(body))
}
func createUpdateEmailRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateEmailRequest](http.MethodPut, "/api/auth/email", []byte(body))
}
func createUpdateUsernameRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateUsernameRequest](http.MethodPut, "/api/auth/username", []byte(body))
}
func createUpdatePasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePasswordRequest](http.MethodPut, "/api/auth/password", []byte(body))
}
func createConfirmAccountDeletionRequest(body string) *http.Request {
return createRequestWithDTO[dto.ConfirmAccountDeletionRequest](http.MethodPost, "/api/auth/account/confirm", []byte(body))
}
func createRefreshTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RefreshTokenRequest](http.MethodPost, "/api/auth/refresh", []byte(body))
}
func createRevokeTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RevokeTokenRequest](http.MethodPost, "/api/auth/revoke", []byte(body))
}
func createCreatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.CreatePostRequest](http.MethodPost, "/api/posts", []byte(body))
}
func createUpdatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePostRequest](http.MethodPut, "/api/posts/1", []byte(body))
}
func createVoteRequest(body string) *http.Request {
return createRequestWithDTO[dto.CastVoteRequest](http.MethodPost, "/api/posts/1/vote", []byte(body))
}
func TestParsePagination(t *testing.T) {
tests := []struct {
name string

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"html/template"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
@@ -877,7 +878,8 @@ func (h *PageHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -897,7 +899,8 @@ func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -960,13 +963,15 @@ func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
}
h.clearAuthCookie(w, r)
http.Redirect(w, r, "/login?flash=Email updated. Check your inbox to confirm the new address. You will need to sign in again after verification.", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Email updated. Check your inbox to confirm the new address. You will need to sign in again after verification.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -1022,13 +1027,15 @@ func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, "/settings?flash=Username updated successfully.", http.StatusSeeOther)
redirectURL := "/settings?flash=" + url.QueryEscape("Username updated successfully.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -1140,13 +1147,15 @@ func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, "/settings?flash=Password updated successfully.", http.StatusSeeOther)
redirectURL := "/settings?flash=" + url.QueryEscape("Password updated successfully.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -1204,7 +1213,8 @@ func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, "/settings?flash=Check your inbox for a confirmation link to finish deleting your account.", http.StatusSeeOther)
redirectURL := "/settings?flash=" + url.QueryEscape("Check your inbox for a confirmation link to finish deleting your account.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
@@ -1328,7 +1338,8 @@ func (h *PageHandler) clearAuthCookie(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) Vote(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Please sign in to vote", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Please sign in to vote")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}

View File

@@ -36,11 +36,6 @@ func NewPostHandler(postRepo repositories.PostRepository, titleFetcher services.
type PostResponse = CommonResponse
type UpdatePostRequest struct {
Title string `json:"title"`
Content string `json:"content"`
}
// @Summary Get posts
// @Description Get a list of posts with pagination. Posts include vote statistics (up_votes, down_votes, score) and current user's vote status.
// @Tags posts
@@ -111,7 +106,7 @@ func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body CreatePostRequest true "Post data"
// @Param request body dto.CreatePostRequest true "Post data"
// @Success 201 {object} PostResponse
// @Failure 400 {object} PostResponse "Invalid request data or validation failed"
// @Failure 401 {object} PostResponse "Authentication required"
@@ -120,32 +115,9 @@ func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} PostResponse "Internal server error"
// @Router /api/posts [post]
func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
var req struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
req.Title = security.SanitizeInput(req.Title)
req.URL = security.SanitizeURL(req.URL)
req.Content = security.SanitizePostContent(req.Content)
if req.URL == "" {
SendErrorResponse(w, "URL is required", http.StatusBadRequest)
return
}
if len(req.Title) > 200 {
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
return
}
if len(req.Content) > 10000 {
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
req, ok := GetValidatedDTO[dto.CreatePostRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -154,13 +126,20 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return
}
title := req.Title
title := security.SanitizeInput(req.Title)
url := security.SanitizeURL(req.URL)
content := security.SanitizePostContent(req.Content)
if url == "" {
SendErrorResponse(w, "Invalid URL", http.StatusBadRequest)
return
}
if title == "" && h.titleFetcher != nil {
titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second)
defer cancel()
fetchedTitle, err := h.titleFetcher.FetchTitle(titleCtx, req.URL)
fetchedTitle, err := h.titleFetcher.FetchTitle(titleCtx, url)
if err != nil {
switch {
case errors.Is(err, services.ErrUnsupportedScheme):
@@ -186,10 +165,20 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return
}
if len(title) > 200 {
SendErrorResponse(w, "Title must be at most 200 characters", http.StatusBadRequest)
return
}
if len(content) > 10000 {
SendErrorResponse(w, "Content must be at most 10000 characters", http.StatusBadRequest)
return
}
post := &database.Post{
Title: title,
URL: req.URL,
Content: req.Content,
URL: url,
Content: content,
AuthorID: &userID,
}
@@ -257,7 +246,7 @@ func (h *PostHandler) SearchPosts(w http.ResponseWriter, r *http.Request) {
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Param request body UpdatePostRequest true "Post update data"
// @Param request body dto.UpdatePostRequest true "Post update data"
// @Success 200 {object} PostResponse "Post updated successfully"
// @Failure 400 {object} PostResponse "Invalid request data or validation failed"
// @Failure 401 {object} PostResponse "Authentication required"
@@ -286,40 +275,27 @@ func (h *PostHandler) UpdatePost(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
Title string `json:"title"`
Content string `json:"content"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdatePostRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
req.Title = security.SanitizeInput(req.Title)
req.Content = security.SanitizePostContent(req.Content)
title := security.SanitizeInput(req.Title)
content := security.SanitizePostContent(req.Content)
if len(req.Title) > 200 {
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
return
}
if len(req.Content) > 10000 {
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
return
}
if err := validation.ValidateTitle(req.Title); err != nil {
if err := validation.ValidateTitle(title); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidateContent(req.Content); err != nil {
if err := validation.ValidateContent(content); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
post.Title = req.Title
post.Content = req.Content
post.Title = title
post.Content = content
if err := h.postRepo.Update(post); err != nil {
SendErrorResponse(w, "Failed to update post", http.StatusInternalServerError)
@@ -458,7 +434,7 @@ func (h *PostHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
if config.GeneralRateLimit != nil {
protected = config.GeneralRateLimit(protected)
}
protected.Post("/posts", h.CreatePost)
protected.Put("/posts/{id}", h.UpdatePost)
protected.Post("/posts", WithValidation[dto.CreatePostRequest](config.ValidationMiddleware, h.CreatePost))
protected.Put("/posts/{id}", WithValidation[dto.UpdatePostRequest](config.ValidationMiddleware, h.UpdatePost))
protected.Delete("/posts/{id}", h.DeletePost)
}

View File

@@ -69,9 +69,8 @@ func TestPostHandlerCreatePostWithTitleFetcher(t *testing.T) {
handler := NewPostHandler(repo, titleFetcher, nil)
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"url":"https://example.com","content":"Test content"}`))
request := createCreatePostRequest(`{"url":"https://example.com","content":"Test content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
@@ -171,7 +170,7 @@ func TestPostHandlerUpdatePostUnauthorized(t *testing.T) {
handler := NewPostHandler(repo, nil, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/1", bytes.NewBufferString(`{"title":"Updated Title","content":"Updated content"}`))
request := createUpdatePostRequest(`{"title":"Updated Title","content":"Updated content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
request.Header.Set("Content-Type", "application/json")
@@ -278,8 +277,7 @@ func TestPostHandlerCreatePostSuccess(t *testing.T) {
handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com","content":"Go"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request := createCreatePostRequest(`{"title":" ","url":"https://example.com","content":"Go"}`)
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42))
request = request.WithContext(ctx)
@@ -297,7 +295,7 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"","url":"","content":""}`))
request := createCreatePostRequest(`{"title":"","url":"","content":""}`)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
@@ -305,14 +303,14 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`invalid json`))
request = createCreatePostRequest(`invalid json`)
handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"ok","url":"https://example.com"}`))
request = createCreatePostRequest(`{"title":"ok","url":"https://example.com"}`)
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
}
@@ -336,8 +334,7 @@ func TestPostHandlerCreatePostTitleFetcherErrors(t *testing.T) {
return "", tc.err
}}
handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request := createCreatePostRequest(`{"title":" ","url":"https://example.com"}`)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
@@ -495,7 +492,7 @@ func TestPostHandlerUpdatePost(t *testing.T) {
}
handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/"+tt.postID, bytes.NewBufferString(tt.requestBody))
request := createUpdatePostRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -669,6 +666,9 @@ func (e *errorVoteRepository) Delete(uint) error { ret
func (e *errorVoteRepository) Count() (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByPostID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByUserID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) GetVoteCountsByPostID(uint) (int, int, error) {
return 0, 0, errors.New("database error")
}
func (e *errorVoteRepository) WithTx(*gorm.DB) repositories.VoteRepository { return e }
func TestPostHandler_EdgeCases(t *testing.T) {

View File

@@ -18,4 +18,5 @@ type RouteModuleConfig struct {
AuthRateLimit func(chi.Router) chi.Router
CSRFMiddleware func(http.Handler) http.Handler
AuthMiddleware func(http.Handler) http.Handler
ValidationMiddleware func(http.Handler) http.Handler
}

View File

@@ -24,10 +24,6 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
t.Run("XSS_"+payload[:minLen(20, len(payload))], func(t *testing.T) {
repo := &testutils.PostRepositoryStub{
CreateFn: func(post *database.Post) error {
sanitizedTitle := security.SanitizeInput(payload)
if post.Title != sanitizedTitle {
t.Errorf("Expected sanitized title, got %q", post.Title)
}
return nil
},
}
@@ -41,14 +37,46 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
}
body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
if recorder.Code != http.StatusCreated {
t.Errorf("Expected status %d, got %d. Body: %s", http.StatusCreated, recorder.Code, recorder.Body.String())
return
}
var response CommonResponse
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if !response.Success {
t.Errorf("Expected successful response, got error: %s", response.Error)
return
}
dataMap, ok := response.Data.(map[string]any)
if !ok {
t.Fatalf("Expected data to be a map, got %T", response.Data)
}
title, ok := dataMap["title"].(string)
if !ok {
t.Fatalf("Expected title to be a string, got %T", dataMap["title"])
}
expectedSanitized := security.SanitizeInput(payload)
if title != expectedSanitized {
t.Errorf("Expected sanitized title %q, got %q", expectedSanitized, title)
}
if title == payload {
t.Errorf("Title was not sanitized - original payload %q matches response %q", payload, title)
}
})
}
}
@@ -123,7 +151,7 @@ func TestPostHandler_InputValidation(t *testing.T) {
}
body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
@@ -230,7 +258,7 @@ func TestAuthHandler_PasswordValidation(t *testing.T) {
}
body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
@@ -290,7 +318,7 @@ func TestAuthHandler_UsernameSanitization(t *testing.T) {
}
body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()

View File

@@ -91,7 +91,7 @@ func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body RegisterRequest true "User data"
// @Param request body dto.RegisterRequest true "User data"
// @Success 201 {object} UserResponse "User created successfully"
// @Failure 400 {object} UserResponse "Invalid request data or validation failed"
// @Failure 401 {object} UserResponse "Authentication required"
@@ -99,13 +99,9 @@ func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} UserResponse "Internal server error"
// @Router /api/users [post]
func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RegisterRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -189,7 +185,7 @@ func (h *UserHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
}
protected.Get("/users", h.GetUsers)
protected.Post("/users", h.CreateUser)
protected.Post("/users", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.CreateUser))
protected.Get("/users/{id}", h.GetUser)
protected.Get("/users/{id}/posts", h.GetUserPosts)
}

View File

@@ -1,7 +1,6 @@
package handlers
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -103,7 +102,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
return nil
}})
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
request := createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
@@ -126,14 +125,14 @@ func TestUserHandlerCreateUser(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString("invalid"))
request = createRegisterRequest("invalid")
handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"","email":"","password":""}`))
request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode)
@@ -144,7 +143,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
}
handler = newUserHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
}
@@ -350,7 +349,7 @@ func TestUserHandler_PasswordValidation(t *testing.T) {
handler := NewUserHandler(repo, authService)
requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password)
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(requestBody))
request := createRegisterRequest(requestBody)
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/services"
"github.com/go-chi/chi/v5"
@@ -39,11 +40,6 @@ func NewVoteHandler(voteService *services.VoteService) *VoteHandler {
}
}
// @Description Vote request with type field. All votes are handled the same way.
type VoteRequest struct {
Type string `json:"type" example:"up" enums:"up,down,none" description:"Vote type: 'up' for upvote, 'down' for downvote, 'none' to remove vote"`
}
type VoteResponse = CommonResponse
// @Summary Cast a vote on a post
@@ -62,7 +58,7 @@ type VoteResponse = CommonResponse
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Param request body VoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)"
// @Param request body dto.CastVoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)"
// @Success 200 {object} VoteResponse "Vote cast successfully with updated post statistics"
// @Failure 401 {object} VoteResponse "Authentication required"
// @Failure 400 {object} VoteResponse "Invalid request data or vote type"
@@ -82,8 +78,9 @@ func (h *VoteHandler) CastVote(w http.ResponseWriter, r *http.Request) {
return
}
var req VoteRequest
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.CastVoteRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -286,7 +283,7 @@ func (h *VoteHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
protected = config.GeneralRateLimit(protected)
}
protected.Post("/posts/{id}/vote", h.CastVote)
protected.Post("/posts/{id}/vote", WithValidation[dto.CastVoteRequest](config.ValidationMiddleware, h.CastVote))
protected.Delete("/posts/{id}/vote", h.RemoveVote)
protected.Get("/posts/{id}/vote", h.GetUserVote)
protected.Get("/posts/{id}/votes", h.GetPostVotes)

View File

@@ -1,7 +1,6 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -59,13 +58,13 @@ func TestVoteHandlerCastVote(t *testing.T) {
handler := newVoteHandlerWithRepos()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/abc/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -73,7 +72,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`invalid`))
request = createVoteRequest(`invalid`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -83,7 +82,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"maybe"}`))
request = createVoteRequest(`{"type":"maybe"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -93,7 +92,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -101,7 +100,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -111,7 +110,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx)
@@ -125,7 +124,7 @@ func TestVoteHandlerCastVotePostNotFound(t *testing.T) {
handler, _, posts := newVoteHandlerWithReposRefs()
delete(posts, 1)
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -164,7 +163,7 @@ func TestVoteHandlerRemoveVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -202,7 +201,7 @@ func TestVoteHandlerRemoveVotePostNotFound(t *testing.T) {
func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) {
handler, voteRepo, _ := newVoteHandlerWithReposRefs()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -257,7 +256,7 @@ func TestVoteHandlerGetUserVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -301,7 +300,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -311,7 +310,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -345,7 +344,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1"
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -363,7 +362,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -373,7 +372,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -404,7 +403,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1"
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -414,7 +413,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -424,7 +423,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx)
@@ -452,7 +451,7 @@ func TestVoteFlowRegression(t *testing.T) {
t.Run("ErrorHandlingEdgeCases", func(t *testing.T) {
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(``))
request := createVoteRequest(``)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -460,7 +459,7 @@ func TestVoteFlowRegression(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{}`))
request = createVoteRequest(`{}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -470,7 +469,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"invalid"}`))
request = createVoteRequest(`{"type":"invalid"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)

View File

@@ -53,25 +53,25 @@ func TestIntegration_Caching(t *testing.T) {
router := ctx.Router
t.Run("Cache_Hit_On_Repeated_Requests", func(t *testing.T) {
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond)
req2 := httptest.NewRequest("GET", "/api/posts", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("GET", "/api/posts", nil)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
if rec1.Code != rec2.Code {
if firstRecorder.Code != secondRecorder.Code {
t.Error("Cached responses should have same status code")
}
if rec1.Body.String() != rec2.Body.String() {
if firstRecorder.Body.String() != secondRecorder.Body.String() {
t.Error("Cached responses should have same body")
}
if rec2.Header().Get("X-Cache") != "HIT" {
if secondRecorder.Header().Get("X-Cache") != "HIT" {
t.Log("Cache may not be enabled for this path or response may not be cacheable")
}
})
@@ -80,9 +80,9 @@ func TestIntegration_Caching(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "cache_post_user", "cache_post@example.com")
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond)
@@ -92,12 +92,12 @@ func TestIntegration_Caching(t *testing.T) {
"content": "Test content",
}
body, _ := json.Marshal(postBody)
req2 := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
secondRequest.Header.Set("Content-Type", "application/json")
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
time.Sleep(10 * time.Millisecond)
@@ -105,17 +105,17 @@ func TestIntegration_Caching(t *testing.T) {
rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3)
if rec1.Body.String() == rec3.Body.String() && rec1.Code == http.StatusOK && rec3.Code == http.StatusOK {
if firstRecorder.Body.String() == rec3.Body.String() && firstRecorder.Code == http.StatusOK && rec3.Code == http.StatusOK {
t.Log("Cache invalidation may not be working or cache may not be enabled")
}
})
t.Run("Cache_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if rec.Header().Get("Cache-Control") == "" && rec.Header().Get("X-Cache") == "" {
if recorder.Header().Get("Cache-Control") == "" && recorder.Header().Get("X-Cache") == "" {
t.Log("Cache headers may not be present for all responses")
}
})
@@ -126,18 +126,18 @@ func TestIntegration_Caching(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Cache Delete Post", "https://example.com/cache-delete")
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond)
req2 := httptest.NewRequest("DELETE", "/api/posts/"+fmt.Sprintf("%d", post.ID), nil)
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
req2 = testutils.WithURLParams(req2, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("DELETE", "/api/posts/"+fmt.Sprintf("%d", post.ID), nil)
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRequest = testutils.WithURLParams(secondRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
time.Sleep(10 * time.Millisecond)
@@ -145,7 +145,7 @@ func TestIntegration_Caching(t *testing.T) {
rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3)
if rec1.Body.String() == rec3.Body.String() && rec1.Code == http.StatusOK && rec3.Code == http.StatusOK {
if firstRecorder.Body.String() == rec3.Body.String() && firstRecorder.Code == http.StatusOK && rec3.Code == http.StatusOK {
t.Log("Cache invalidation may not be working or cache may not be enabled")
}
})

View File

@@ -1,15 +1,14 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"goyco/internal/middleware"
"goyco/internal/services"
"goyco/internal/testutils"
)
@@ -20,17 +19,8 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "logout_user", "logout@example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/logout", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makePostRequest(t, ctx.Router, "/api/auth/logout", map[string]any{}, user, nil)
assertStatus(t, request, http.StatusOK)
})
t.Run("Auth_Revoke_Token_Endpoint", func(t *testing.T) {
@@ -42,52 +32,23 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Fatalf("Failed to login: %v", err)
}
reqBody := map[string]string{
"refresh_token": loginResult.RefreshToken,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/revoke", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makePostRequest(t, ctx.Router, "/api/auth/revoke", map[string]any{"refresh_token": loginResult.RefreshToken}, user, nil)
assertStatus(t, request, http.StatusOK)
})
t.Run("Auth_Revoke_All_Tokens_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "revoke_all_user", "revoke_all@example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/revoke-all", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makePostRequest(t, ctx.Router, "/api/auth/revoke-all", map[string]any{}, user, nil)
assertStatus(t, request, http.StatusOK)
})
t.Run("Auth_Resend_Verification_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
"email": "resend@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/resend-verification", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusOK, http.StatusNotFound)
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/resend-verification", map[string]any{"email": "resend@example.com"})
assertStatusRange(t, request, http.StatusOK, http.StatusNotFound)
})
t.Run("Auth_Confirm_Email_Endpoint", func(t *testing.T) {
@@ -99,109 +60,66 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
token = "test-token"
}
req := httptest.NewRequest("GET", "/api/auth/confirm?token="+url.QueryEscape(token), nil)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusOK, http.StatusBadRequest)
request := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(token))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Auth_Update_Email_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_email_api_user", "update_email_api@example.com")
reqBody := map[string]string{
"email": "newemail@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "newemail@example.com"}, user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if email, ok := data["email"].(string); ok && email != "newemail@example.com" {
t.Errorf("Expected email to be updated, got %s", email)
}
}
}
})
t.Run("Auth_Update_Username_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_username_api_user", "update_username_api@example.com")
reqBody := map[string]string{
"username": "new_username",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/username", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePutRequest(t, ctx.Router, "/api/auth/username", map[string]any{"username": "new_username"}, user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if username, ok := data["username"].(string); ok && username != "new_username" {
t.Errorf("Expected username to be updated, got %s", username)
}
}
}
})
t.Run("Users_List_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_list_user", "users_list@example.com")
req := httptest.NewRequest("GET", "/api/users", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["users"]; !exists {
t.Error("Expected users in response")
}
}
}
})
t.Run("Users_Get_By_ID_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_get_user", "users_get@example.com")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != user.User.ID {
t.Errorf("Expected user ID %d, got %.0f", user.User.ID, id)
}
}
}
}
})
t.Run("Users_Get_Posts_Endpoint", func(t *testing.T) {
@@ -210,17 +128,10 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "User Posts Test", "https://example.com/user-posts")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) == 0 {
t.Error("Expected at least one post in response")
@@ -229,35 +140,24 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Error("Expected posts array in response")
}
}
}
})
t.Run("Users_Create_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_create_admin", "users_create_admin@example.com")
reqBody := map[string]string{
request := makePostRequest(t, ctx.Router, "/api/users", map[string]any{
"username": "created_user",
"email": "created@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
}, user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusCreated)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusCreated)
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["user"]; !exists {
t.Error("Expected user in response")
}
}
}
})
t.Run("Posts_Update_Endpoint", func(t *testing.T) {
@@ -266,30 +166,19 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Update Test Post", "https://example.com/update-test")
reqBody := map[string]string{
request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if postData, ok := data["post"].(map[string]any); ok {
if title, ok := postData["title"].(string); ok && title != "Updated Title" {
t.Errorf("Expected title 'Updated Title', got '%s'", title)
}
}
}
}
})
t.Run("Posts_Delete_Endpoint", func(t *testing.T) {
@@ -298,20 +187,11 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Test Post", "https://example.com/delete-test")
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, request, http.StatusOK)
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
assertStatus(t, getRec, http.StatusNotFound)
getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID))
assertStatus(t, getRequest, http.StatusNotFound)
})
t.Run("Votes_Get_All_Endpoint", func(t *testing.T) {
@@ -319,28 +199,11 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "votes_get_all_user", "votes_get_all@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Votes Test Post", "https://example.com/votes-test")
makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteBody := map[string]string{"type": "up"}
voteBodyBytes, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
req := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if votes, ok := data["votes"].([]any); ok {
if len(votes) == 0 {
t.Error("Expected at least one vote in response")
@@ -349,7 +212,6 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Error("Expected votes array in response")
}
}
}
})
t.Run("Votes_Remove_Endpoint", func(t *testing.T) {
@@ -358,49 +220,461 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Remove Test", "https://example.com/vote-remove")
voteBody := map[string]string{"type": "up"}
voteBodyBytes, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, request, http.StatusOK)
})
t.Run("API_Info_Endpoint", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, ctx.Router, "/api")
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["endpoints"]; !exists {
t.Error("Expected endpoints in API info")
}
}
}
})
t.Run("Swagger_Documentation_Endpoint", func(t *testing.T) {
req := httptest.NewRequest("GET", "/swagger/index.html", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, ctx.Router, "/swagger/index.html")
assertStatusRange(t, request, http.StatusOK, http.StatusNotFound)
})
ctx.Router.ServeHTTP(rec, req)
t.Run("Search_Endpoint_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "search_edge"), uniqueTestEmail(t, "search_edge"))
assertStatusRange(t, rec, http.StatusOK, http.StatusNotFound)
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Searchable Post One", "https://example.com/one")
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Searchable Post Two", "https://example.com/two")
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Different Content", "https://example.com/three")
t.Run("Empty_Search_Results", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=nonexistentterm12345")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) != 0 {
t.Errorf("Expected empty search results, got %d posts", len(posts))
}
}
if count, ok := data["count"].(float64); ok && count != 0 {
t.Errorf("Expected count 0, got %.0f", count)
}
}
})
t.Run("Search_With_Pagination", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=Searchable&limit=1&offset=0")
response := assertJSONResponse(t, request, http.StatusOK)
var firstPostID any
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 1 {
t.Errorf("Expected at most 1 post with limit=1, got %d", len(posts))
}
if len(posts) > 0 {
if post, ok := posts[0].(map[string]any); ok {
firstPostID = post["id"]
}
}
}
if limit, ok := data["limit"].(float64); ok && limit != 1 {
t.Errorf("Expected limit 1 in response, got %.0f", limit)
}
if offset, ok := data["offset"].(float64); ok && offset != 0 {
t.Errorf("Expected offset 0 in response, got %.0f", offset)
}
}
secondRequest := makeGetRequest(t, ctx.Router, "/api/posts/search?q=Searchable&limit=1&offset=1")
secondResponse := assertJSONResponse(t, secondRequest, http.StatusOK)
if data, ok := getDataFromResponse(secondResponse); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 1 {
t.Errorf("Expected at most 1 post with limit=1 and offset=1, got %d", len(posts))
}
if len(posts) > 0 && firstPostID != nil {
if post, ok := posts[0].(map[string]any); ok {
if post["id"] == firstPostID {
t.Error("Expected different post with offset=1, got same post as offset=0")
}
}
}
}
}
})
t.Run("Search_With_Special_Characters", func(t *testing.T) {
specialQueries := []string{
"Searchable%20Post",
"Searchable'Post",
"Searchable\"Post",
"Searchable;Post",
"Searchable--Post",
}
for _, query := range specialQueries {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q="+url.QueryEscape(query))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
}
})
t.Run("Search_Empty_Query", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) != 0 {
t.Errorf("Expected empty results for empty query, got %d posts", len(posts))
}
}
if count, ok := data["count"].(float64); ok && count != 0 {
t.Errorf("Expected count 0 for empty query, got %.0f", count)
}
}
})
t.Run("Search_With_Very_Long_Query", func(t *testing.T) {
longQuery := strings.Repeat("a", 1000)
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q="+url.QueryEscape(longQuery))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Search_Case_Insensitive", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=SEARCHABLE")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) == 0 {
t.Error("Expected case-insensitive search to find posts")
}
}
}
})
})
t.Run("Title_Fetch_Endpoint_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
t.Run("Missing_URL_Parameter", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/title")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Empty_URL_Parameter", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Invalid_URL_Format", func(t *testing.T) {
invalidURLs := []string{
"not-a-url",
"://invalid",
"http://",
"https://",
}
for _, invalidURL := range invalidURLs {
ctx.Suite.TitleFetcher.SetError(services.ErrUnsupportedScheme)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(invalidURL))
assertErrorResponse(t, request, http.StatusBadRequest)
}
})
t.Run("Unsupported_URL_Schemes", func(t *testing.T) {
unsupportedSchemes := []string{
"ftp://example.com",
"file:///etc/passwd",
"javascript:alert(1)",
"data:text/html,<script>alert(1)</script>",
}
for _, schemeURL := range unsupportedSchemes {
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(schemeURL))
assertErrorResponse(t, request, http.StatusBadRequest)
}
})
t.Run("SSRF_Protection_Localhost", func(t *testing.T) {
ssrfURLs := []string{
"http://localhost",
"http://127.0.0.1",
"http://127.0.0.1:8080",
"http://[::1]",
"http://0.0.0.0",
}
for _, ssrfURL := range ssrfURLs {
ctx.Suite.TitleFetcher.SetError(services.ErrSSRFBlocked)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(ssrfURL))
assertStatusRange(t, request, http.StatusBadRequest, http.StatusBadGateway)
}
})
t.Run("SSRF_Protection_Private_IPs", func(t *testing.T) {
privateIPs := []string{
"http://192.168.1.1",
"http://10.0.0.1",
"http://172.16.0.1",
}
for _, privateIP := range privateIPs {
ctx.Suite.TitleFetcher.SetError(services.ErrSSRFBlocked)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(privateIP))
assertStatusRange(t, request, http.StatusBadRequest, http.StatusBadGateway)
}
})
t.Run("Title_Fetch_Error_Handling", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetError(services.ErrTitleNotFound)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=https://example.com/notitle")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Valid_URL_Success", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Valid Title")
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=https://example.com/valid")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if title, ok := data["title"].(string); ok {
if title != "Valid Title" {
t.Errorf("Expected title 'Valid Title', got '%s'", title)
}
} else {
t.Error("Expected title in response data")
}
}
})
})
t.Run("Get_User_Vote_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "vote_edge"), uniqueTestEmail(t, "vote_edge"))
secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "vote_edge2"), uniqueTestEmail(t, "vote_edge2"))
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Edge Test Post", "https://example.com/vote-edge")
t.Run("Get_Vote_When_User_Has_Voted", func(t *testing.T) {
voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRequest, http.StatusOK)
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if hasVote, ok := data["has_vote"].(bool); !ok || !hasVote {
t.Error("Expected has_vote to be true when user has voted")
}
if vote, ok := data["vote"]; !ok || vote == nil {
t.Error("Expected vote object when user has voted")
}
}
})
t.Run("Get_Vote_When_User_Has_Not_Voted", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if hasVote, ok := data["has_vote"].(bool); ok {
if hasVote {
t.Error("Expected has_vote to be false when user has not voted")
}
} else {
t.Error("Expected has_vote field in response")
}
}
})
t.Run("Get_Vote_Invalid_Post_ID", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/posts/999999/vote", user, map[string]string{"id": "999999"})
if request.Code != http.StatusOK && request.Code != http.StatusNotFound {
t.Errorf("Expected status 200 or 404 for invalid post ID, got %d", request.Code)
}
})
t.Run("Get_Vote_Unauthenticated", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID))
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Get_Vote_Response_Structure", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success field to be true")
}
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["has_vote"]; !exists {
t.Error("Expected has_vote field in response data")
}
if _, exists := data["is_anonymous"]; !exists {
t.Error("Expected is_anonymous field in response data")
}
} else {
t.Error("Expected data field in response")
}
})
})
t.Run("Refresh_Token_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
refreshUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "refresh_edge"), uniqueTestEmail(t, "refresh_edge"))
t.Run("Refresh_With_Expired_Token", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
refreshToken, err := ctx.Suite.RefreshTokenRepo.GetByTokenHash(testutils.HashVerificationToken(loginResult.RefreshToken))
if err != nil {
t.Fatalf("Failed to get refresh token: %v", err)
}
refreshToken.ExpiresAt = time.Now().Add(-1 * time.Hour)
if err := ctx.Suite.DB.Model(refreshToken).Update("expires_at", refreshToken.ExpiresAt).Error; err != nil {
t.Fatalf("Failed to expire refresh token: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_With_Revoked_Token", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil {
t.Fatalf("Failed to revoke refresh token: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_With_Empty_Token", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": ""})
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Refresh_With_Missing_Token_Field", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{})
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Refresh_Token_Rotation", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
originalRefreshToken := loginResult.RefreshToken
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": originalRefreshToken})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if newAccessToken, ok := data["access_token"].(string); ok {
if newAccessToken == "" {
t.Error("Expected new access token in refresh response")
}
if newRefreshToken, ok := data["refresh_token"].(string); ok {
if newRefreshToken != "" && newRefreshToken == originalRefreshToken {
t.Log("Refresh token rotation may not be implemented (same token returned)")
}
}
}
}
})
t.Run("Refresh_After_Account_Lock", func(t *testing.T) {
lockedUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "refresh_lock"), uniqueTestEmail(t, "refresh_lock"))
loginResult, err := ctx.AuthService.Login(lockedUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
lockedUser.User.Locked = true
if err := ctx.Suite.UserRepo.Update(lockedUser.User); err != nil {
t.Fatalf("Failed to lock user: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertStatusRange(t, request, http.StatusUnauthorized, http.StatusForbidden)
})
t.Run("Refresh_With_Invalid_Token_Format", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": "invalid-token-format-12345"})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
})
t.Run("Pagination_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
paginationUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "pagination_edge"), uniqueTestEmail(t, "pagination_edge"))
for i := 0; i < 5; i++ {
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, paginationUser.User.ID, fmt.Sprintf("Pagination Post %d", i), fmt.Sprintf("https://example.com/pag%d", i))
}
t.Run("Negative_Limit", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?limit=-1")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Negative_Offset", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?offset=-1")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Very_Large_Limit", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?limit=10000")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Very_Large_Offset", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?offset=10000")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 0 {
t.Logf("Large offset returned %d posts (may be expected)", len(posts))
}
}
}
})
t.Run("Invalid_Pagination_Parameters", func(t *testing.T) {
invalidParams := []string{
"limit=abc",
"offset=xyz",
"limit=",
"offset=",
}
for _, param := range invalidParams {
request := makeGetRequest(t, ctx.Router, "/api/posts?"+param)
assertStatus(t, request, http.StatusOK)
}
})
})
}

View File

@@ -1,17 +1,12 @@
package integration
import (
"bytes"
"compress/gzip"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"goyco/internal/middleware"
"goyco/internal/testutils"
)
func TestIntegration_Compression(t *testing.T) {
@@ -19,16 +14,16 @@ func TestIntegration_Compression(t *testing.T) {
router := ctx.Router
t.Run("Response_Compression_Gzip", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip")
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
request.Header.Set("Accept-Encoding", "gzip")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
contentEncoding := rec.Header().Get("Content-Encoding")
contentEncoding := recorder.Header().Get("Content-Encoding")
if contentEncoding != "" && strings.Contains(contentEncoding, "gzip") {
assertHeaderContains(t, rec, "Content-Encoding", "gzip")
reader, err := gzip.NewReader(rec.Body)
assertHeaderContains(t, recorder, "Content-Encoding", "gzip")
reader, err := gzip.NewReader(recorder.Body)
if err != nil {
t.Fatalf("Failed to create gzip reader: %v", err)
}
@@ -48,14 +43,14 @@ func TestIntegration_Compression(t *testing.T) {
})
t.Run("Compression_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip, deflate")
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
request.Header.Set("Accept-Encoding", "gzip, deflate")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Header().Get("Vary") != "" {
assertHeaderContains(t, rec, "Vary", "Accept-Encoding")
if recorder.Header().Get("Vary") != "" {
assertHeaderContains(t, recorder, "Vary", "Accept-Encoding")
} else {
t.Log("Vary header may not always be present")
}
@@ -67,25 +62,19 @@ func TestIntegration_StaticFiles(t *testing.T) {
router := ctx.Router
t.Run("Robots_Txt_Served", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/robots.txt")
router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
if !strings.Contains(rec.Body.String(), "User-agent") {
if !strings.Contains(request.Body.String(), "User-agent") {
t.Error("Expected robots.txt content")
}
})
t.Run("Static_Files_Security_Headers", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/robots.txt")
router.ServeHTTP(rec, req)
if rec.Header().Get("X-Content-Type-Options") == "" {
if request.Header().Get("X-Content-Type-Options") == "" {
t.Log("Security headers may not be applied to all static files")
}
})
@@ -101,32 +90,22 @@ func TestIntegration_URLMetadata(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Fetched Title")
postBody := map[string]string{
postBody := map[string]any{
"title": "Test Post",
"url": "https://example.com/metadata-test",
"content": "Test content",
}
body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePostRequest(t, router, "/api/posts", postBody, user, nil)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
})
t.Run("URL_Metadata_Endpoint", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Endpoint Title")
req := httptest.NewRequest("GET", "/api/posts/title?url=https://example.com/test", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/api/posts/title?url=https://example.com/test")
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["title"]; !exists {

View File

@@ -1,14 +1,10 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"goyco/internal/middleware"
"goyco/internal/testutils"
)
@@ -22,33 +18,19 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Owner Post", "https://example.com/owner")
updateBody := map[string]string{
request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(updateBody)
}, otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
assertErrorResponse(t, request, http.StatusForbidden)
ctx.Router.ServeHTTP(rec, req)
request = makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}, owner, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertErrorResponse(t, rec, http.StatusForbidden)
req = httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+owner.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, owner.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("Post_Delete_Authorization", func(t *testing.T) {
@@ -58,47 +40,27 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Delete Post", "https://example.com/delete")
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, request, http.StatusForbidden)
assertErrorResponse(t, rec, http.StatusForbidden)
request = makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), owner, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req = httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+owner.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, owner.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("User_Profile_Access_Authorization", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user1", "profile_user1@example.com")
user2 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user2", "profile_user2@example.com")
firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user1", "profile_user1@example.com")
secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user2", "profile_user2@example.com")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user1.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user2.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user2.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user1.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", firstUser.User.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", firstUser.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != user1.User.ID {
t.Errorf("Expected user ID %d, got %.0f", user1.User.ID, id)
}
if id, ok := userData["id"].(float64); ok && uint(id) != firstUser.User.ID {
t.Errorf("Expected user ID %d, got %.0f", firstUser.User.ID, id)
}
}
}
@@ -109,24 +71,13 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_user", "settings_auth@example.com")
otherUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_other", "settings_auth_other@example.com")
updateBody := map[string]string{
"email": "newemail@example.com",
}
body, _ := json.Marshal(updateBody)
request := makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "newemail@example.com"}, otherUser, nil)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response == nil {
return
}
if data, ok := response["data"].(map[string]any); ok {
if data, ok := getDataFromResponse(response); ok {
if userData, ok := data["user"].(map[string]any); ok {
if email, ok := userData["email"].(string); ok && email == "newemail@example.com" {
if id, ok := userData["id"].(float64); ok && uint(id) != otherUser.User.ID {
@@ -136,20 +87,9 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
}
}
updateBody2 := map[string]string{
"email": "anothernewemail@example.com",
}
body2, _ := json.Marshal(updateBody2)
request = makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "anothernewemail@example.com"}, user, nil)
req = httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body2))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("Vote_Authorization", func(t *testing.T) {
@@ -159,73 +99,42 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, postOwner.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
request := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
assertStatus(t, request, http.StatusOK)
ctx.Router.ServeHTTP(rec, req)
request = makePostRequestWithJSON(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"})
assertStatus(t, rec, http.StatusOK)
req = httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Protected_Endpoint_Without_Auth", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, ctx.Router, "/api/posts", map[string]any{})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Protected_Endpoint_With_Invalid_Token", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer invalid-token")
rec := httptest.NewRecorder()
request := makeRequest(t, ctx.Router, "POST", "/api/posts", []byte("{}"), map[string]string{"Content-Type": "application/json", "Authorization": "Bearer invalid-token"})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("User_List_Authorization", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "list_auth_user", "list_auth@example.com")
req := httptest.NewRequest("GET", "/api/users", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
request = makeGetRequest(t, ctx.Router, "/api/users")
req = httptest.NewRequest("GET", "/api/users", nil)
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_Token_Authorization", func(t *testing.T) {
@@ -237,18 +146,9 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
t.Fatalf("Failed to login: %v", err)
}
refreshBody := map[string]string{
"refresh_token": loginResult.RefreshToken,
}
body, _ := json.Marshal(refreshBody)
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
req := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response == nil {
return
}
@@ -260,17 +160,8 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
t.Error("Expected data field in refresh response")
}
refreshBody = map[string]string{
"refresh_token": "invalid-refresh-token",
}
body, _ = json.Marshal(refreshBody)
request = makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": "invalid-refresh-token"})
req = httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
}

View File

@@ -14,166 +14,137 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
ctx := setupPageHandlerTestContext(t)
router := ctx.Router
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) {
reqBody := url.Values{}
reqBody.Set("username", "testuser")
reqBody.Set("email", "test@example.com")
reqBody.Set("password", "SecurePass123!")
getCSRFToken := func(t *testing.T, path string, cookies ...*http.Cookie) *http.Cookie {
t.Helper()
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", rec.Code, rec.Body.String())
request := httptest.NewRequest("GET", path, nil)
for _, c := range cookies {
request.AddCookie(c)
}
if !strings.Contains(rec.Body.String(), "Invalid CSRF token") {
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
for _, cookie := range recorder.Result().Cookies() {
if cookie.Name == "csrf_token" {
return cookie
}
}
t.Fatalf("Expected CSRF cookie to be set for %s", path)
return nil
}
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) {
requestBody := url.Values{}
requestBody.Set("username", "testuser")
requestBody.Set("email", "test@example.com")
requestBody.Set("password", "SecurePass123!")
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", recorder.Code, recorder.Body.String())
}
if !strings.Contains(recorder.Body.String(), "Invalid CSRF token") {
t.Error("Expected CSRF error message")
}
})
t.Run("CSRF_Allows_Form_With_Valid_Token", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
csrfCookie := getCSRFToken(t, "/register")
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
}
}
requestBody := url.Values{}
requestBody.Set("username", "csrf_user")
requestBody.Set("email", "csrf@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", csrfCookie.Value)
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value
router.ServeHTTP(recorder, request)
reqBody := url.Values{}
reqBody.Set("username", "csrf_user")
reqBody.Set("email", "csrf@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected form submission with valid CSRF token to succeed")
}
})
t.Run("CSRF_Allows_API_Requests", func(t *testing.T) {
reqBody := map[string]string{
requestBody := map[string]string{
"username": "api_user",
"email": "api@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected API requests to bypass CSRF protection")
}
})
t.Run("CSRF_Blocks_Mismatched_Token", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
csrfCookie := getCSRFToken(t, "/register")
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
requestBody := url.Values{}
requestBody.Set("username", "mismatch_user")
requestBody.Set("email", "mismatch@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", "wrong-token")
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", recorder.Code, recorder.Body.String())
}
}
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
reqBody := url.Values{}
reqBody.Set("username", "mismatch_user")
reqBody.Set("email", "mismatch@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", "wrong-token")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", rec.Code, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "Invalid CSRF token") {
if !strings.Contains(recorder.Body.String(), "Invalid CSRF token") {
t.Error("Expected CSRF error message")
}
})
t.Run("CSRF_Allows_GET_Requests", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/register", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected GET requests to bypass CSRF protection")
}
})
t.Run("CSRF_Token_In_Header", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
csrfCookie := getCSRFToken(t, "/register")
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
}
}
requestBody := url.Values{}
requestBody.Set("username", "header_user")
requestBody.Set("email", "header@example.com")
requestBody.Set("password", "SecurePass123!")
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.Header.Set("X-CSRF-Token", csrfCookie.Value)
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value
router.ServeHTTP(recorder, request)
reqBody := url.Values{}
reqBody.Set("username", "header_user")
reqBody.Set("email", "header@example.com")
reqBody.Set("password", "SecurePass123!")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-CSRF-Token", csrfToken)
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected CSRF token in header to be accepted")
}
})
@@ -182,41 +153,24 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "csrf_form_user", "csrf_form@example.com")
getReq := httptest.NewRequest("GET", "/posts/new", nil)
getReq.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
authCookie := &http.Cookie{Name: "auth_token", Value: user.Token}
csrfCookie := getCSRFToken(t, "/posts/new", authCookie)
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
}
}
requestBody := url.Values{}
requestBody.Set("title", "CSRF Test Post")
requestBody.Set("url", "https://example.com/csrf-test")
requestBody.Set("content", "Test content")
requestBody.Set("csrf_token", csrfCookie.Value)
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
request := httptest.NewRequest("POST", "/posts", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(authCookie)
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value
router.ServeHTTP(recorder, request)
reqBody := url.Values{}
reqBody.Set("title", "CSRF Test Post")
reqBody.Set("url", "https://example.com/csrf-test")
reqBody.Set("content", "Test content")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/posts", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected post creation with valid CSRF token to succeed")
}
})

View File

@@ -19,27 +19,18 @@ func TestIntegration_DataConsistency(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "consistency_user", "consistency@example.com")
postBody := map[string]string{
request := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "Consistency Test Post",
"url": "https://example.com/consistency",
"content": "Test content",
}
body, _ := json.Marshal(postBody)
}, user, nil)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
createResponse := assertJSONResponse(t, rec, http.StatusCreated)
createResponse := assertJSONResponse(t, request, http.StatusCreated)
if createResponse == nil {
return
}
postData, ok := createResponse["data"].(map[string]any)
postData, ok := getDataFromResponse(createResponse)
if !ok {
t.Fatal("Response missing data")
}
@@ -53,16 +44,14 @@ func TestIntegration_DataConsistency(t *testing.T) {
createdURL := postData["url"]
createdContent := postData["content"]
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getResponse := assertJSONResponse(t, getRec, http.StatusOK)
getResponse := assertJSONResponse(t, getRequest, http.StatusOK)
if getResponse == nil {
return
}
getPostData, ok := getResponse["data"].(map[string]any)
getPostData, ok := getDataFromResponse(getResponse)
if !ok {
t.Fatal("Get response missing data")
}
@@ -96,32 +85,17 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Consistency Post", "https://example.com/vote-consistency")
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRequest, http.StatusOK)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
getVotesRequest := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRec, http.StatusOK)
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRequest, http.StatusOK)
if votesResponse == nil {
return
}
votesData, ok := votesResponse["data"].(map[string]any)
votesData, ok := getDataFromResponse(votesResponse)
if !ok {
t.Fatal("Votes response missing data")
}
@@ -172,32 +146,21 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Original Title", "https://example.com/original")
updateBody := map[string]string{
updateRequest := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(updateBody)
}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(updateRec, updateReq)
assertStatus(t, updateRequest, http.StatusOK)
assertStatus(t, updateRec, http.StatusOK)
getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID))
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getResponse := assertJSONResponse(t, getRec, http.StatusOK)
getResponse := assertJSONResponse(t, getRequest, http.StatusOK)
if getResponse == nil {
return
}
getPostData, ok := getResponse["data"].(map[string]any)
getPostData, ok := getDataFromResponse(getResponse)
if !ok {
t.Fatal("Get response missing data")
}
@@ -215,18 +178,12 @@ func TestIntegration_DataConsistency(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "user_posts_consistency", "user_posts_consistency@example.com")
post1 := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 1", "https://example.com/post1")
post2 := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 2", "https://example.com/post2")
firstPost := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 1", "https://example.com/post1")
secondPost := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 2", "https://example.com/post2")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response == nil {
return
}
@@ -245,26 +202,26 @@ func TestIntegration_DataConsistency(t *testing.T) {
t.Errorf("Expected at least 2 posts, got %d", len(posts))
}
foundPost1 := false
foundPost2 := false
foundFirstPost := false
foundSecondPost := false
for _, post := range posts {
if postMap, ok := post.(map[string]any); ok {
if postID, ok := postMap["id"].(float64); ok {
if uint(postID) == post1.ID {
foundPost1 = true
if uint(postID) == firstPost.ID {
foundFirstPost = true
}
if uint(postID) == post2.ID {
foundPost2 = true
if uint(postID) == secondPost.ID {
foundSecondPost = true
}
}
}
}
if !foundPost1 {
if !foundFirstPost {
t.Error("Post 1 not found in user posts")
}
if !foundPost2 {
if !foundSecondPost {
t.Error("Post 2 not found in user posts")
}
})
@@ -275,20 +232,20 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Consistency Post", "https://example.com/delete-consistency")
deleteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteReq.Header.Set("Authorization", "Bearer "+user.Token)
deleteReq = testutils.WithUserContext(deleteReq, middleware.UserIDKey, user.User.ID)
deleteReq = testutils.WithURLParams(deleteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(deleteRec, deleteReq)
deleteRequest := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteRequest.Header.Set("Authorization", "Bearer "+user.Token)
deleteRequest = testutils.WithUserContext(deleteRequest, middleware.UserIDKey, user.User.ID)
deleteRequest = testutils.WithURLParams(deleteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(deleteRecorder, deleteRequest)
assertStatus(t, deleteRec, http.StatusOK)
assertStatus(t, deleteRecorder, http.StatusOK)
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRecorder, getRequest)
assertStatus(t, getRec, http.StatusNotFound)
assertStatus(t, getRecorder, http.StatusNotFound)
})
t.Run("Vote_Removal_Consistency", func(t *testing.T) {
@@ -300,33 +257,33 @@ func TestIntegration_DataConsistency(t *testing.T) {
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
voteRequest := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteRequest.Header.Set("Content-Type", "application/json")
voteRequest.Header.Set("Authorization", "Bearer "+user.Token)
voteRequest = testutils.WithUserContext(voteRequest, middleware.UserIDKey, user.User.ID)
voteRequest = testutils.WithURLParams(voteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRecorder, voteRequest)
assertStatus(t, voteRec, http.StatusOK)
assertStatus(t, voteRecorder, http.StatusOK)
removeVoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteReq.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteReq = testutils.WithUserContext(removeVoteReq, middleware.UserIDKey, user.User.ID)
removeVoteReq = testutils.WithURLParams(removeVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(removeVoteRec, removeVoteReq)
removeVoteRequest := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteRequest.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteRequest = testutils.WithUserContext(removeVoteRequest, middleware.UserIDKey, user.User.ID)
removeVoteRequest = testutils.WithURLParams(removeVoteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(removeVoteRecorder, removeVoteRequest)
assertStatus(t, removeVoteRec, http.StatusOK)
assertStatus(t, removeVoteRecorder, http.StatusOK)
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
getVotesRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesRequest.Header.Set("Authorization", "Bearer "+user.Token)
getVotesRequest = testutils.WithUserContext(getVotesRequest, middleware.UserIDKey, user.User.ID)
getVotesRequest = testutils.WithURLParams(getVotesRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRecorder, getVotesRequest)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRecorder, http.StatusOK)
if votesResponse == nil {
return
}

View File

@@ -22,41 +22,39 @@ func TestIntegration_EdgeCases(t *testing.T) {
expiredToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDAwMDAwMDB9.expired"
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+expiredToken)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/auth/me", nil)
request.Header.Set("Authorization", "Bearer "+expiredToken)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, recorder, http.StatusUnauthorized)
})
t.Run("Concurrent_Vote_Operations", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "vote_user1", "vote1@example.com")
firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "vote_user1", "vote1@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user1.User.ID, "Concurrent Vote Post", "https://example.com/concurrent")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, firstUser.User.ID, "Concurrent Vote Post", "https://example.com/concurrent")
var wg sync.WaitGroup
errors := make(chan error, 10)
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range 5 {
wg.Go(func() {
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user1.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user1.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
errors <- fmt.Errorf("unexpected status: %d", rec.Code)
request := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+firstUser.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, firstUser.User.ID)
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("unexpected status: %d", recorder.Code)
}
}()
})
}
wg.Wait()
@@ -72,8 +70,8 @@ func TestIntegration_EdgeCases(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "large_user", "large@example.com")
largeContent := make([]byte, 10001)
for i := range largeContent {
largeContent[i] = 'a'
for idx := range largeContent {
largeContent[idx] = 'a'
}
postBody := map[string]string{
@@ -82,36 +80,36 @@ func TestIntegration_EdgeCases(t *testing.T) {
"content": string(largeContent),
}
body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, recorder, http.StatusBadRequest)
smallContent := make([]byte, 1000)
for i := range smallContent {
smallContent[i] = 'a'
for idx := range smallContent {
smallContent[idx] = 'a'
}
postBody2 := map[string]string{
secondPostBody := map[string]string{
"title": "Small Post",
"url": "https://example.com/small",
"content": string(smallContent),
}
body2, _ := json.Marshal(postBody2)
req2 := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body2))
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
secondBody, _ := json.Marshal(secondPostBody)
secondRequest := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(secondBody))
secondRequest.Header.Set("Content-Type", "application/json")
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec2, req2)
ctx.Router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec2, http.StatusCreated)
assertStatus(t, secondRecorder, http.StatusCreated)
})
t.Run("Malformed_JSON_Payloads", func(t *testing.T) {
@@ -127,15 +125,15 @@ func TestIntegration_EdgeCases(t *testing.T) {
}
for _, payload := range malformedPayloads {
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(payload))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(payload))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, recorder, http.StatusBadRequest)
}
})
@@ -148,38 +146,36 @@ func TestIntegration_EdgeCases(t *testing.T) {
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
assertStatus(t, voteRec, http.StatusOK)
voteRequest := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteRequest.Header.Set("Content-Type", "application/json")
voteRequest.Header.Set("Authorization", "Bearer "+user.Token)
voteRequest = testutils.WithUserContext(voteRequest, middleware.UserIDKey, user.User.ID)
voteRequest = testutils.WithURLParams(voteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRecorder, voteRequest)
assertStatus(t, voteRecorder, http.StatusOK)
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
}()
for range 3 {
wg.Go(func() {
request := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(recorder, request)
})
}
wg.Wait()
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
getVotesRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesRequest.Header.Set("Authorization", "Bearer "+user.Token)
getVotesRequest = testutils.WithUserContext(getVotesRequest, middleware.UserIDKey, user.User.ID)
getVotesRequest = testutils.WithURLParams(getVotesRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRecorder, getVotesRequest)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRecorder, http.StatusOK)
if votesResponse != nil {
if data, ok := votesResponse["data"].(map[string]any); ok {
if votes, ok := data["votes"].([]any); ok {

View File

@@ -1,14 +1,10 @@
package integration
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/testutils"
)
@@ -19,19 +15,14 @@ func TestIntegration_EmailService(t *testing.T) {
t.Run("Registration_Email_Sent", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "email_reg_user",
"email": "email_reg@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/register", reqBody)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
token := ctx.Suite.EmailSender.VerificationToken()
if token == "" {
@@ -52,15 +43,10 @@ func TestIntegration_EmailService(t *testing.T) {
t.Fatalf("Failed to create user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"username_or_email": "email_reset_user",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
makePostRequestWithJSON(t, router, "/api/auth/forgot-password", reqBody)
token := ctx.Suite.EmailSender.PasswordResetToken()
if token == "" {
@@ -72,17 +58,9 @@ func TestIntegration_EmailService(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_del_user", "email_del@example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeDeleteRequest(t, router, "/api/auth/account", user, nil)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
token := ctx.Suite.EmailSender.DeletionToken()
if token == "" {
@@ -94,17 +72,10 @@ func TestIntegration_EmailService(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_change_user", "email_change@example.com")
reqBody := map[string]string{
reqBody := map[string]any{
"email": "newemail@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
makePutRequest(t, router, "/api/auth/email", reqBody, user, nil)
token := ctx.Suite.EmailSender.VerificationToken()
if token == "" {
@@ -115,17 +86,12 @@ func TestIntegration_EmailService(t *testing.T) {
t.Run("Email_Template_Content", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "template_user",
"email": "template@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
makePostRequestWithJSON(t, router, "/api/auth/register", reqBody)
token := ctx.Suite.EmailSender.VerificationToken()
if token == "" {

View File

@@ -1,7 +1,6 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -10,7 +9,7 @@ import (
"strings"
"testing"
"goyco/internal/middleware"
"goyco/internal/database"
"goyco/internal/testutils"
)
@@ -20,46 +19,34 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Run("Complete_Registration_To_Post_Creation_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
registerBody := map[string]string{
registerRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "journey_user",
"email": "journey@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(registerBody)
registerReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
registerReq.Header.Set("Content-Type", "application/json")
registerRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(registerRec, registerReq)
})
assertStatus(t, registerRec, http.StatusCreated)
assertStatus(t, registerRequest, http.StatusCreated)
verificationToken := ctx.Suite.EmailSender.VerificationToken()
if verificationToken == "" {
t.Fatal("Verification token not sent")
}
confirmReq := httptest.NewRequest("GET", "/api/auth/confirm?token="+url.QueryEscape(verificationToken), nil)
confirmRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(confirmRec, confirmReq)
confirmRequest := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(verificationToken))
assertStatus(t, confirmRec, http.StatusOK)
assertStatus(t, confirmRequest, http.StatusOK)
loginBody := map[string]string{
loginRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/login", map[string]any{
"username": "journey_user",
"password": "SecurePass123!",
}
loginBodyBytes, _ := json.Marshal(loginBody)
loginReq := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBuffer(loginBodyBytes))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(loginRec, loginReq)
})
loginResponse := assertJSONResponse(t, loginRec, http.StatusOK)
loginResponse := assertJSONResponse(t, loginRequest, http.StatusOK)
if loginResponse == nil {
return
}
data, ok := loginResponse["data"].(map[string]any)
data, ok := getDataFromResponse(loginResponse)
if !ok {
t.Fatal("Login response missing data")
}
@@ -67,8 +54,8 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
var token string
if accessToken, ok := data["access_token"].(string); ok && accessToken != "" {
token = accessToken
} else if tokenVal, ok := data["token"].(string); ok && tokenVal != "" {
token = tokenVal
} else if tokenValue, ok := data["token"].(string); ok && tokenValue != "" {
token = tokenValue
} else {
t.Fatal("Login response missing access_token or token")
}
@@ -90,25 +77,19 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatalf("Login response missing user.id. Data: %+v", data)
}
postBody := map[string]string{
postBodyBytes, _ := json.Marshal(map[string]any{
"title": "Journey Test Post",
"url": "https://example.com/journey",
"content": "Test content",
}
postBodyBytes, _ := json.Marshal(postBody)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBodyBytes))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, uint(userID))
postRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(postRec, postReq)
})
postRequest := makeAuthenticatedRequest(t, ctx.Router, "POST", "/api/posts", postBodyBytes, &authenticatedUser{User: &database.User{ID: uint(userID)}, Token: token}, nil)
postResponse := assertJSONResponse(t, postRec, http.StatusCreated)
postResponse := assertJSONResponse(t, postRequest, http.StatusCreated)
if postResponse == nil {
return
}
postData, ok := postResponse["data"].(map[string]any)
postData, ok := getDataFromResponse(postResponse)
if !ok {
t.Fatal("Post response missing data")
}
@@ -118,56 +99,39 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Post response missing id")
}
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil)
getPostRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getPostRec, getPostReq)
getPostRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
assertStatus(t, getPostRec, http.StatusOK)
assertStatus(t, getPostRequest, http.StatusOK)
})
t.Run("Complete_Password_Reset_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "reset_journey_user", "reset_journey@example.com")
resetBody := map[string]string{
resetRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/forgot-password", map[string]any{
"username_or_email": "reset_journey@example.com",
}
resetBodyBytes, _ := json.Marshal(resetBody)
resetReq := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(resetBodyBytes))
resetReq.Header.Set("Content-Type", "application/json")
resetRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(resetRec, resetReq)
})
assertStatus(t, resetRec, http.StatusOK)
assertStatus(t, resetRequest, http.StatusOK)
resetToken := ctx.Suite.EmailSender.GetLastPasswordResetToken()
if resetToken == "" {
t.Fatal("Password reset token not sent")
}
newPasswordBody := map[string]string{
newPasswordRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/reset-password", map[string]any{
"token": resetToken,
"new_password": "NewSecurePass123!",
}
newPasswordBodyBytes, _ := json.Marshal(newPasswordBody)
newPasswordReq := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(newPasswordBodyBytes))
newPasswordReq.Header.Set("Content-Type", "application/json")
newPasswordRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(newPasswordRec, newPasswordReq)
})
assertStatus(t, newPasswordRec, http.StatusOK)
assertStatus(t, newPasswordRequest, http.StatusOK)
loginBody := map[string]string{
loginRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/login", map[string]any{
"username": "reset_journey_user",
"password": "NewSecurePass123!",
}
loginBodyBytes, _ := json.Marshal(loginBody)
loginReq := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBuffer(loginBodyBytes))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(loginRec, loginReq)
})
assertStatus(t, loginRec, http.StatusOK)
assertStatus(t, loginRequest, http.StatusOK)
})
t.Run("Complete_Vote_And_Unvote_Journey", func(t *testing.T) {
@@ -176,40 +140,21 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Journey Post", "https://example.com/vote-journey")
voteBody := map[string]string{"type": "up"}
voteBodyBytes, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRequest, http.StatusOK)
assertStatus(t, voteRec, http.StatusOK)
getVotesRequest := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRequest, http.StatusOK)
if votesResponse == nil {
return
}
if data, ok := votesResponse["data"].(map[string]any); ok {
if data, ok := getDataFromResponse(votesResponse); ok {
if votes, ok := data["votes"].([]any); ok && len(votes) > 0 {
unvoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
unvoteReq.Header.Set("Authorization", "Bearer "+user.Token)
unvoteReq = testutils.WithUserContext(unvoteReq, middleware.UserIDKey, user.User.ID)
unvoteReq = testutils.WithURLParams(unvoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
unvoteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(unvoteRec, unvoteReq)
unvoteRequest := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, unvoteRec, http.StatusOK)
assertStatus(t, unvoteRequest, http.StatusOK)
}
}
})
@@ -221,31 +166,31 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
csrfToken := getCSRFToken(t, pageRouter, "/register")
reqBody := url.Values{}
reqBody.Set("username", "page_journey_user")
reqBody.Set("email", "page_journey@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("password_confirm", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("username", "page_journey_user")
requestBody.Set("email", "page_journey@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("password_confirm", "SecurePass123!")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
pageRouter.ServeHTTP(rec, req)
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
pageRouter.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
verificationToken := pageCtx.Suite.EmailSender.VerificationToken()
if verificationToken == "" {
t.Fatal("Verification token not sent")
}
confirmReq := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(verificationToken), nil)
confirmRec := httptest.NewRecorder()
pageRouter.ServeHTTP(confirmRec, confirmReq)
confirmRequest := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(verificationToken), nil)
confirmRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(confirmRecorder, confirmRequest)
assertStatusRange(t, confirmRec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, confirmRecorder, http.StatusOK, http.StatusSeeOther)
loginCSRFToken := getCSRFToken(t, pageRouter, "/login")
@@ -254,15 +199,15 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
loginBody.Set("password", "SecurePass123!")
loginBody.Set("csrf_token", loginCSRFToken)
loginReq := httptest.NewRequest("POST", "/login", strings.NewReader(loginBody.Encode()))
loginReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
loginReq.AddCookie(&http.Cookie{Name: "csrf_token", Value: loginCSRFToken})
loginRec := httptest.NewRecorder()
pageRouter.ServeHTTP(loginRec, loginReq)
loginRequest := httptest.NewRequest("POST", "/login", strings.NewReader(loginBody.Encode()))
loginRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded")
loginRequest.AddCookie(&http.Cookie{Name: "csrf_token", Value: loginCSRFToken})
loginRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(loginRecorder, loginRequest)
assertStatus(t, loginRec, http.StatusSeeOther)
assertStatus(t, loginRecorder, http.StatusSeeOther)
loginCookies := loginRec.Result().Cookies()
loginCookies := loginRecorder.Result().Cookies()
var authToken string
for _, cookie := range loginCookies {
if cookie.Name == "auth_token" {
@@ -275,37 +220,31 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Auth token not set after login")
}
homeReq := httptest.NewRequest("GET", "/", nil)
homeReq.AddCookie(&http.Cookie{Name: "auth_token", Value: authToken})
homeRec := httptest.NewRecorder()
pageRouter.ServeHTTP(homeRec, homeReq)
homeRequest := httptest.NewRequest("GET", "/", nil)
homeRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: authToken})
homeRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(homeRecorder, homeRequest)
assertStatus(t, homeRec, http.StatusOK)
assertStatus(t, homeRecorder, http.StatusOK)
})
t.Run("Complete_Post_Creation_And_Update_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "post_update_journey_user", "post_update_journey@example.com")
postBody := map[string]string{
postBodyBytes, _ := json.Marshal(map[string]any{
"title": "Original Title",
"url": "https://example.com/original",
"content": "Original content",
}
postBodyBytes, _ := json.Marshal(postBody)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBodyBytes))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+user.Token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, user.User.ID)
postRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(postRec, postReq)
})
postRequest := makeAuthenticatedRequest(t, ctx.Router, "POST", "/api/posts", postBodyBytes, user, nil)
postResponse := assertJSONResponse(t, postRec, http.StatusCreated)
postResponse := assertJSONResponse(t, postRequest, http.StatusCreated)
if postResponse == nil {
return
}
postData, ok := postResponse["data"].(map[string]any)
postData, ok := getDataFromResponse(postResponse)
if !ok {
t.Fatal("Post response missing data")
}
@@ -315,34 +254,25 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Post response missing id")
}
updateBody := map[string]string{
updateBodyBytes, _ := json.Marshal(map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
updateBodyBytes, _ := json.Marshal(updateBody)
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%.0f", postID), bytes.NewBuffer(updateBodyBytes))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%.0f", postID)})
updateRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(updateRec, updateReq)
})
updateRequest := makeAuthenticatedRequest(t, ctx.Router, "PUT", fmt.Sprintf("/api/posts/%.0f", postID), updateBodyBytes, user, map[string]string{"id": fmt.Sprintf("%.0f", postID)})
updateResponse := assertJSONResponse(t, updateRec, http.StatusOK)
updateResponse := assertJSONResponse(t, updateRequest, http.StatusOK)
if updateResponse == nil {
return
}
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil)
getPostRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getPostRec, getPostReq)
getPostRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getPostResponse := assertJSONResponse(t, getPostRec, http.StatusOK)
getPostResponse := assertJSONResponse(t, getPostRequest, http.StatusOK)
if getPostResponse == nil {
return
}
if data, ok := getPostResponse["data"].(map[string]any); ok {
if data, ok := getDataFromResponse(getPostResponse); ok {
if post, ok := data["post"].(map[string]any); ok {
if title, ok := post["title"].(string); ok && title != "Updated Title" {
t.Errorf("Post title not updated: expected 'Updated Title', got '%s'", title)

View File

@@ -2,7 +2,6 @@ package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
@@ -19,58 +18,46 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "json_error_user", "json_error@example.com")
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("invalid json{")))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("invalid json{")))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, recorder, http.StatusBadRequest)
})
t.Run("Validation_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "",
"email": "invalid-email",
"password": "weak",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Database_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "db_error_user", "db_error@example.com")
reqBody := map[string]string{
reqBody := map[string]any{
"title": "Test Post",
"url": "https://example.com/test",
"content": "Test content",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePostRequest(t, ctx.Router, "/api/posts", reqBody, user, nil)
ctx.Router.ServeHTTP(rec, req)
if rec.Code == http.StatusInternalServerError {
assertErrorResponse(t, rec, http.StatusInternalServerError)
if request.Code == http.StatusInternalServerError {
assertErrorResponse(t, request, http.StatusInternalServerError)
} else {
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
}
})
@@ -78,34 +65,23 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "notfound_error_user", "notfound_error@example.com")
req := httptest.NewRequest("GET", "/api/posts/999999", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": "999999"})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/posts/999999", user, map[string]string{"id": "999999"})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusNotFound)
assertErrorResponse(t, request, http.StatusNotFound)
})
t.Run("Unauthorized_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"title": "Test Post",
"url": "https://example.com/test",
"content": "Test content",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, ctx.Router, "/api/posts", reqBody)
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Forbidden_Error_Propagation", func(t *testing.T) {
@@ -115,79 +91,59 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Forbidden Post", "https://example.com/forbidden")
updateBody := map[string]string{
updateBody := map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), updateBody, otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusForbidden)
assertErrorResponse(t, request, http.StatusForbidden)
})
t.Run("Service_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "existing_user",
"email": "existing@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
req = httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
request = makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
assertStatusRange(t, rec, http.StatusBadRequest, http.StatusConflict)
assertErrorResponse(t, rec, rec.Code)
assertStatusRange(t, request, http.StatusBadRequest, http.StatusConflict)
assertErrorResponse(t, request, request.Code)
})
t.Run("Middleware_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer expired.invalid.token")
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer expired.invalid.token")
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, recorder, http.StatusUnauthorized)
})
t.Run("Handler_Error_Response_Format", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("GET", "/api/nonexistent", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, ctx.Router, "/api/nonexistent")
ctx.Router.ServeHTTP(rec, req)
if rec.Code == http.StatusNotFound {
if rec.Header().Get("Content-Type") == "application/json" {
assertErrorResponse(t, rec, http.StatusNotFound)
} else {
if rec.Body.Len() == 0 {
if request.Code == http.StatusNotFound {
if request.Header().Get("Content-Type") == "application/json" {
assertErrorResponse(t, request, http.StatusNotFound)
} else if request.Body.Len() == 0 {
t.Error("Expected error response body")
}
}
}
})
}

View File

@@ -11,55 +11,35 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"goyco/internal/database"
"goyco/internal/handlers"
"goyco/internal/middleware"
"goyco/internal/repositories"
"goyco/internal/services"
"goyco/internal/testutils"
"github.com/golang-jwt/jwt/v5"
)
func TestIntegration_Handlers(t *testing.T) {
suite := testutils.NewServiceSuite(t)
authService, err := services.NewAuthFacadeForTest(testutils.AppTestConfig, suite.UserRepo, suite.PostRepo, suite.DeletionRepo, suite.RefreshTokenRepo, suite.EmailSender)
if err != nil {
t.Fatalf("Failed to create auth service: %v", err)
}
voteService := services.NewVoteService(suite.VoteRepo, suite.PostRepo, suite.DB)
emailSender := suite.EmailSender
userRepo := suite.UserRepo
postRepo := suite.PostRepo
titleFetcher := suite.TitleFetcher
authHandler := handlers.NewAuthHandler(authService, userRepo)
postHandler := handlers.NewPostHandler(postRepo, titleFetcher, voteService)
voteHandler := handlers.NewVoteHandler(voteService)
userHandler := handlers.NewUserHandler(userRepo, authService)
ctx := setupTestContext(t)
authService := ctx.AuthService
emailSender := ctx.Suite.EmailSender
userRepo := ctx.Suite.UserRepo
postRepo := ctx.Suite.PostRepo
t.Run("Auth_Handler_Complete_Workflow", func(t *testing.T) {
emailSender.Reset()
registerData := map[string]string{
registerResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "handler_user",
"email": "handler@example.com",
"password": "SecurePass123!",
}
registerBody, _ := json.Marshal(registerData)
registerReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(registerBody))
registerReq.Header.Set("Content-Type", "application/json")
registerResp := httptest.NewRecorder()
authHandler.Register(registerResp, registerReq)
if registerResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", registerResp.Code)
})
if registerResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", registerResponse.Code)
}
var registerPayload map[string]any
if err := json.Unmarshal(registerResp.Body.Bytes(), &registerPayload); err != nil {
t.Fatalf("Failed to decode register response: %v", err)
}
registerPayload := assertJSONResponse(t, registerResponse, http.StatusCreated)
if success, _ := registerPayload["success"].(bool); !success {
t.Fatalf("Expected register response success, got %v", registerPayload)
}
@@ -78,11 +58,9 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to update user with mock token: %v", err)
}
confirmReq := httptest.NewRequest(http.MethodGet, "/api/auth/confirm?token="+url.QueryEscape(mockToken), nil)
confirmResp := httptest.NewRecorder()
authHandler.ConfirmEmail(confirmResp, confirmReq)
if confirmResp.Code != http.StatusOK {
t.Fatalf("Expected 200 when confirming email via handler, got %d", confirmResp.Code)
confirmResponse := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(mockToken))
if confirmResponse.Code != http.StatusOK {
t.Fatalf("Expected 200 when confirming email via handler, got %d", confirmResponse.Code)
}
loginSeed := createAuthenticatedUser(t, authService, userRepo, "auth_handler_login", "auth_handler_login@example.com")
@@ -92,92 +70,60 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Service login failed for seeded user: %v", err)
}
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+loginAuth.AccessToken)
meReq = testutils.WithUserContext(meReq, middleware.UserIDKey, loginSeed.User.ID)
meResp := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
if meResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", meResp.Code)
meResponse := makeAuthenticatedGetRequest(t, ctx.Router, "/api/auth/me", &authenticatedUser{User: loginSeed.User, Token: loginAuth.AccessToken}, nil)
if meResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", meResponse.Code)
}
})
t.Run("Auth_Handler_Security_Validation", func(t *testing.T) {
emailSender.Reset()
weakData := map[string]string{
weakResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "weak_user",
"email": "weak@example.com",
"password": "123",
}
weakBody, _ := json.Marshal(weakData)
weakReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(weakBody))
weakReq.Header.Set("Content-Type", "application/json")
weakResp := httptest.NewRecorder()
authHandler.Register(weakResp, weakReq)
if weakResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for weak password, got %d", weakResp.Code)
})
if weakResponse.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for weak password, got %d", weakResponse.Code)
}
var weakErrorResp map[string]any
if err := json.Unmarshal(weakResp.Body.Bytes(), &weakErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := weakErrorResp["success"].(bool); success {
weakErrorResponse := assertJSONResponse(t, weakResponse, http.StatusBadRequest)
if success, _ := weakErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := weakErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := weakErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
invalidData := map[string]string{
invalidResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "invalid_user",
"email": "not-an-email",
"password": "SecurePass123!",
}
invalidBody, _ := json.Marshal(invalidData)
invalidReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(invalidBody))
invalidReq.Header.Set("Content-Type", "application/json")
invalidResp := httptest.NewRecorder()
authHandler.Register(invalidResp, invalidReq)
if invalidResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid email, got %d", invalidResp.Code)
})
if invalidResponse.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid email, got %d", invalidResponse.Code)
}
var invalidEmailErrorResp map[string]any
if err := json.Unmarshal(invalidResp.Body.Bytes(), &invalidEmailErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := invalidEmailErrorResp["success"].(bool); success {
invalidEmailErrorResponse := assertJSONResponse(t, invalidResponse, http.StatusBadRequest)
if success, _ := invalidEmailErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := invalidEmailErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := invalidEmailErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
incompleteData := map[string]string{
incompleteResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "incomplete_user",
}
incompleteBody, _ := json.Marshal(incompleteData)
incompleteReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(incompleteBody))
incompleteReq.Header.Set("Content-Type", "application/json")
incompleteResp := httptest.NewRecorder()
authHandler.Register(incompleteResp, incompleteReq)
if incompleteResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for missing fields, got %d", incompleteResp.Code)
})
if incompleteResponse.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for missing fields, got %d", incompleteResponse.Code)
}
var incompleteErrorResp map[string]any
if err := json.Unmarshal(incompleteResp.Body.Bytes(), &incompleteErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := incompleteErrorResp["success"].(bool); success {
incompleteErrorResponse := assertJSONResponse(t, incompleteResponse, http.StatusBadRequest)
if success, _ := incompleteErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := incompleteErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := incompleteErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
})
@@ -186,28 +132,17 @@ func TestIntegration_Handlers(t *testing.T) {
emailSender.Reset()
user := createAuthenticatedUser(t, authService, userRepo, "post_user", "post@example.com")
postData := map[string]string{
postResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "Handler Test Post",
"url": "https://example.com/handler-test",
"content": "This is a handler test post",
}
postBody, _ := json.Marshal(postData)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBody))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+user.Token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, user.User.ID)
postResp := httptest.NewRecorder()
postHandler.CreatePost(postResp, postReq)
if postResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", postResp.Code)
}, user, nil)
if postResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", postResponse.Code)
}
var postResult map[string]any
if err := json.Unmarshal(postResp.Body.Bytes(), &postResult); err != nil {
t.Fatalf("Failed to decode post response: %v", err)
}
postDetails, ok := postResult["data"].(map[string]any)
postResult := assertJSONResponse(t, postResponse, http.StatusCreated)
postDetails, ok := getDataFromResponse(postResult)
if !ok {
t.Fatalf("Expected data object in post response, got %v", postResult)
}
@@ -216,87 +151,49 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatal("Expected post ID in response")
}
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", int(postID)), nil)
getReq = testutils.WithURLParams(getReq, map[string]string{"id": fmt.Sprintf("%d", int(postID))})
getResp := httptest.NewRecorder()
postHandler.GetPost(getResp, getReq)
if getResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getResp.Code)
getResponse := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", int(postID)))
if getResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getResponse.Code)
}
postsReq := httptest.NewRequest("GET", "/api/posts", nil)
postsResp := httptest.NewRecorder()
postHandler.GetPosts(postsResp, postsReq)
if postsResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", postsResp.Code)
postsResponse := makeGetRequest(t, ctx.Router, "/api/posts")
if postsResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", postsResponse.Code)
}
searchReq := httptest.NewRequest("GET", "/api/posts/search?q=handler", nil)
searchResp := httptest.NewRecorder()
postHandler.SearchPosts(searchResp, searchReq)
if searchResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", searchResp.Code)
searchResponse := makeGetRequest(t, ctx.Router, "/api/posts/search?q=handler")
if searchResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", searchResponse.Code)
}
})
t.Run("Post_Handler_Security_Validation", func(t *testing.T) {
emailSender.Reset()
postData := map[string]string{
postResponse := makePostRequestWithJSON(t, ctx.Router, "/api/posts", map[string]any{
"title": "Unauthorized Post",
"url": "https://example.com/unauthorized",
"content": "This should fail",
}
postBody, _ := json.Marshal(postData)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBody))
postReq.Header.Set("Content-Type", "application/json")
postResp := httptest.NewRecorder()
postHandler.CreatePost(postResp, postReq)
if postResp.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401 for unauthenticated post creation, got %d", postResp.Code)
}
var authErrorResp map[string]any
if err := json.Unmarshal(postResp.Body.Bytes(), &authErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := authErrorResp["success"].(bool); success {
})
authErrorResponse := assertJSONResponse(t, postResponse, http.StatusUnauthorized)
if success, _ := authErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := authErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := authErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain authentication error message")
}
user := createAuthenticatedUser(t, authService, userRepo, "security_user", "security@example.com")
invalidData := map[string]string{
invalidResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "",
"url": "not-a-url",
"content": "Invalid post",
}
invalidBody, _ := json.Marshal(invalidData)
invalidReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(invalidBody))
invalidReq.Header.Set("Content-Type", "application/json")
invalidReq.Header.Set("Authorization", "Bearer "+user.Token)
invalidReq = testutils.WithUserContext(invalidReq, middleware.UserIDKey, user.User.ID)
invalidResp := httptest.NewRecorder()
postHandler.CreatePost(invalidResp, invalidReq)
if invalidResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid post data, got %d", invalidResp.Code)
}
var postValidationErrorResp map[string]any
if err := json.Unmarshal(invalidResp.Body.Bytes(), &postValidationErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := postValidationErrorResp["success"].(bool); success {
}, user, nil)
postValidationErrorResponse := assertJSONResponse(t, invalidResponse, http.StatusBadRequest)
if success, _ := postValidationErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := postValidationErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := postValidationErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
})
@@ -306,161 +203,100 @@ func TestIntegration_Handlers(t *testing.T) {
user := createAuthenticatedUser(t, authService, userRepo, "vote_handler_user", "vote_handler@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user.User.ID, "Vote Handler Test Post", "https://example.com/vote-handler")
voteData := map[string]string{
"type": "up",
}
voteBody, _ := json.Marshal(voteData)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBody))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteResp := httptest.NewRecorder()
voteResponse := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteResponse, http.StatusOK)
voteHandler.CastVote(voteResp, voteReq)
if voteResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", voteResp.Code)
}
getVoteResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, getVoteResponse, http.StatusOK)
getVoteReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
getVoteReq.Header.Set("Authorization", "Bearer "+user.Token)
getVoteReq = testutils.WithUserContext(getVoteReq, middleware.UserIDKey, user.User.ID)
getVoteReq = testutils.WithURLParams(getVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVoteResp := httptest.NewRecorder()
getPostVotesResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, getPostVotesResponse, http.StatusOK)
voteHandler.GetUserVote(getVoteResp, getVoteReq)
if getVoteResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getVoteResp.Code)
}
getPostVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getPostVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getPostVotesReq = testutils.WithUserContext(getPostVotesReq, middleware.UserIDKey, user.User.ID)
getPostVotesReq = testutils.WithURLParams(getPostVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostVotesResp := httptest.NewRecorder()
voteHandler.GetPostVotes(getPostVotesResp, getPostVotesReq)
if getPostVotesResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getPostVotesResp.Code)
}
removeVoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteReq.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteReq = testutils.WithUserContext(removeVoteReq, middleware.UserIDKey, user.User.ID)
removeVoteReq = testutils.WithURLParams(removeVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteResp := httptest.NewRecorder()
voteHandler.RemoveVote(removeVoteResp, removeVoteReq)
if removeVoteResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", removeVoteResp.Code)
}
removeVoteResponse := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, removeVoteResponse, http.StatusOK)
})
t.Run("User_Handler_Complete_Workflow", func(t *testing.T) {
emailSender.Reset()
user := createAuthenticatedUser(t, authService, userRepo, "user_handler_user", "user_handler@example.com")
usersReq := httptest.NewRequest("GET", "/api/users", nil)
usersReq.Header.Set("Authorization", "Bearer "+user.Token)
usersReq = testutils.WithUserContext(usersReq, middleware.UserIDKey, user.User.ID)
usersResp := httptest.NewRecorder()
usersResponse := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
assertStatus(t, usersResponse, http.StatusOK)
userHandler.GetUsers(usersResp, usersReq)
if usersResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", usersResp.Code)
}
getUserResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
assertStatus(t, getUserResponse, http.StatusOK)
getUserReq := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user.User.ID), nil)
getUserReq.Header.Set("Authorization", "Bearer "+user.Token)
getUserReq = testutils.WithUserContext(getUserReq, middleware.UserIDKey, user.User.ID)
getUserReq = testutils.WithURLParams(getUserReq, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserResp := httptest.NewRecorder()
userHandler.GetUser(getUserResp, getUserReq)
if getUserResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getUserResp.Code)
}
getUserPostsReq := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil)
getUserPostsReq.Header.Set("Authorization", "Bearer "+user.Token)
getUserPostsReq = testutils.WithUserContext(getUserPostsReq, middleware.UserIDKey, user.User.ID)
getUserPostsReq = testutils.WithURLParams(getUserPostsReq, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserPostsResp := httptest.NewRecorder()
userHandler.GetUserPosts(getUserPostsResp, getUserPostsReq)
if getUserPostsResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getUserPostsResp.Code)
}
getUserPostsResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
assertStatus(t, getUserPostsResponse, http.StatusOK)
})
t.Run("Error_Handling_Invalid_Requests", func(t *testing.T) {
invalidJSONReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer([]byte("invalid json")))
invalidJSONReq.Header.Set("Content-Type", "application/json")
invalidJSONResp := httptest.NewRecorder()
middleware.StopAllRateLimiters()
ctx.Suite.EmailSender.Reset()
authHandler.Register(invalidJSONResp, invalidJSONReq)
if invalidJSONResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid JSON, got %d", invalidJSONResp.Code)
}
var jsonErrorResp map[string]any
if err := json.Unmarshal(invalidJSONResp.Body.Bytes(), &jsonErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := jsonErrorResp["success"].(bool); success {
invalidJSONResponse := makeRequest(t, ctx.Router, "POST", "/api/auth/register", []byte("invalid json"), map[string]string{"Content-Type": "application/json"})
jsonErrorResponse := assertJSONResponse(t, invalidJSONResponse, http.StatusBadRequest)
if success, _ := jsonErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := jsonErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := jsonErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain JSON parsing error message")
}
missingCTData := map[string]string{
"username": "missing_ct_user",
"email": "missing_ct@example.com",
"username": uniqueTestUsername(t, "missing_ct"),
"email": uniqueTestEmail(t, "missing_ct"),
"password": "SecurePass123!",
}
missingCTBody, _ := json.Marshal(missingCTData)
missingCTReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(missingCTBody))
missingCTResp := httptest.NewRecorder()
missingCTRequest := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(missingCTBody))
missingCTResponse := httptest.NewRecorder()
authHandler.Register(missingCTResp, missingCTReq)
if missingCTResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", missingCTResp.Code)
ctx.Router.ServeHTTP(missingCTResponse, missingCTRequest)
if missingCTResponse.Code == http.StatusTooManyRequests {
var rateLimitResponse map[string]any
if err := json.Unmarshal(missingCTResponse.Body.Bytes(), &rateLimitResponse); err != nil {
t.Errorf("Rate limited but response is not valid JSON: %v", err)
} else {
t.Logf("Rate limit hit (expected in full test suite run), but request was processed correctly (not rejected as invalid JSON)")
}
} else if missingCTResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", missingCTResponse.Code)
}
invalidEndpointReq := httptest.NewRequest("GET", "/api/invalid/endpoint", nil)
invalidEndpointResp := httptest.NewRecorder()
invalidEndpointRequest := httptest.NewRequest("GET", "/api/invalid/endpoint", nil)
invalidEndpointResponse := httptest.NewRecorder()
authHandler.Me(invalidEndpointResp, invalidEndpointReq)
if invalidEndpointResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(invalidEndpointResponse, invalidEndpointRequest)
if invalidEndpointResponse.Code == http.StatusOK {
t.Error("Expected error for invalid endpoint")
}
})
t.Run("Security_Authentication_Bypass", func(t *testing.T) {
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meResp := httptest.NewRecorder()
meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
if meResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(meResponse, meRequest)
if meResponse.Code == http.StatusOK {
t.Error("Expected error for unauthenticated request")
}
invalidTokenReq := httptest.NewRequest("GET", "/api/auth/me", nil)
invalidTokenReq.Header.Set("Authorization", "Bearer invalid-token")
invalidTokenResp := httptest.NewRecorder()
invalidTokenRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
invalidTokenRequest.Header.Set("Authorization", "Bearer invalid-token")
invalidTokenResponse := httptest.NewRecorder()
authHandler.Me(invalidTokenResp, invalidTokenReq)
if invalidTokenResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(invalidTokenResponse, invalidTokenRequest)
if invalidTokenResponse.Code == http.StatusOK {
t.Error("Expected error for invalid token")
}
malformedTokenReq := httptest.NewRequest("GET", "/api/auth/me", nil)
malformedTokenReq.Header.Set("Authorization", "InvalidFormat token")
malformedTokenResp := httptest.NewRecorder()
malformedTokenRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
malformedTokenRequest.Header.Set("Authorization", "InvalidFormat token")
malformedTokenResponse := httptest.NewRecorder()
authHandler.Me(malformedTokenResp, malformedTokenReq)
if malformedTokenResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(malformedTokenResponse, malformedTokenRequest)
if malformedTokenResponse.Code == http.StatusOK {
t.Error("Expected error for malformed token")
}
})
@@ -468,32 +304,21 @@ func TestIntegration_Handlers(t *testing.T) {
t.Run("Security_Input_Sanitization", func(t *testing.T) {
user := createAuthenticatedUser(t, authService, userRepo, "xss_user", "xss@example.com")
xssData := map[string]string{
xssResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "<script>alert('xss')</script>",
"url": "https://example.com/xss",
"content": "XSS test content",
}
xssBody, _ := json.Marshal(xssData)
xssReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(xssBody))
xssReq.Header.Set("Content-Type", "application/json")
xssReq.Header.Set("Authorization", "Bearer "+user.Token)
xssReq = testutils.WithUserContext(xssReq, middleware.UserIDKey, user.User.ID)
xssResp := httptest.NewRecorder()
postHandler.CreatePost(xssResp, xssReq)
if xssResp.Code != http.StatusCreated {
t.Errorf("Expected status 201 for XSS sanitization, got %d", xssResp.Code)
"content": "<script>alert('xss')</script>",
}, user, nil)
if xssResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201 for XSS sanitization, got %d", xssResponse.Code)
}
var xssResult map[string]any
if err := json.Unmarshal(xssResp.Body.Bytes(), &xssResult); err != nil {
t.Fatalf("Failed to decode XSS response: %v", err)
}
xssResult := assertJSONResponse(t, xssResponse, http.StatusCreated)
if success, _ := xssResult["success"].(bool); !success {
t.Error("Expected XSS response to have success=true")
}
data, ok := xssResult["data"].(map[string]any)
data, ok := getDataFromResponse(xssResult)
if !ok {
t.Fatalf("Expected data object in XSS response, got %T", xssResult["data"])
}
@@ -522,32 +347,21 @@ func TestIntegration_Handlers(t *testing.T) {
t.Errorf("Expected script tags to be HTML-escaped in content, got: %s", content)
}
sqlData := map[string]string{
sqlResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "'; DROP TABLE posts; --",
"url": "https://example.com/sql",
"content": "SQL injection test",
}
sqlBody, _ := json.Marshal(sqlData)
sqlReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(sqlBody))
sqlReq.Header.Set("Content-Type", "application/json")
sqlReq.Header.Set("Authorization", "Bearer "+user.Token)
sqlReq = testutils.WithUserContext(sqlReq, middleware.UserIDKey, user.User.ID)
sqlResp := httptest.NewRecorder()
postHandler.CreatePost(sqlResp, sqlReq)
if sqlResp.Code != http.StatusCreated {
t.Errorf("Expected status 201 for SQL injection sanitization, got %d", sqlResp.Code)
}, user, nil)
if sqlResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201 for SQL injection sanitization, got %d", sqlResponse.Code)
}
var sqlResult map[string]any
if err := json.Unmarshal(sqlResp.Body.Bytes(), &sqlResult); err != nil {
t.Fatalf("Failed to decode SQL response: %v", err)
}
sqlResult := assertJSONResponse(t, sqlResponse, http.StatusCreated)
if success, _ := sqlResult["success"].(bool); !success {
t.Error("Expected SQL response to have success=true")
}
sqlResponseData, ok := sqlResult["data"].(map[string]any)
sqlResponseData, ok := getDataFromResponse(sqlResult)
if !ok {
t.Fatalf("Expected data object in SQL response, got %T", sqlResult["data"])
}
@@ -579,63 +393,31 @@ func TestIntegration_Handlers(t *testing.T) {
t.Run("Authorization_User_Access_Control", func(t *testing.T) {
emailSender.Reset()
user1 := createAuthenticatedUser(t, authService, userRepo, "auth_user1", "auth1@example.com")
user2 := createAuthenticatedUser(t, authService, userRepo, "auth_user2", "auth2@example.com")
firstUser := createAuthenticatedUser(t, authService, userRepo, "auth_user1", "auth1@example.com")
secondUser := createAuthenticatedUser(t, authService, userRepo, "auth_user2", "auth2@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user1.User.ID, "Private Post", "https://example.com/private")
post := testutils.CreatePostWithRepo(t, postRepo, firstUser.User.ID, "Private Post", "https://example.com/private")
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getPostReq.Header.Set("Authorization", "Bearer "+user2.Token)
getPostReq = testutils.WithUserContext(getPostReq, middleware.UserIDKey, user2.User.ID)
getPostReq = testutils.WithURLParams(getPostReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostResp := httptest.NewRecorder()
getPostResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
testutils.AssertHTTPStatus(t, getPostResponse, http.StatusOK)
postHandler.GetPost(getPostResp, getPostReq)
testutils.AssertHTTPStatus(t, getPostResp, http.StatusOK)
updateResponse := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{"title": "Updated Title"}, secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
testutils.AssertHTTPStatus(t, updateResponse, http.StatusForbidden)
updateData := map[string]string{
"title": "Updated Title",
}
updateBody, _ := json.Marshal(updateData)
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(updateBody))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user2.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user2.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateResp := httptest.NewRecorder()
postHandler.UpdatePost(updateResp, updateReq)
testutils.AssertHTTPStatus(t, updateResp, http.StatusForbidden)
deleteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteReq.Header.Set("Authorization", "Bearer "+user2.Token)
deleteReq = testutils.WithUserContext(deleteReq, middleware.UserIDKey, user2.User.ID)
deleteReq = testutils.WithURLParams(deleteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteResp := httptest.NewRecorder()
postHandler.DeletePost(deleteResp, deleteReq)
testutils.AssertHTTPStatus(t, deleteResp, http.StatusForbidden)
deleteResponse := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
testutils.AssertHTTPStatus(t, deleteResponse, http.StatusForbidden)
})
t.Run("Authorization_Vote_Access_Control", func(t *testing.T) {
emailSender.Reset()
user1 := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user1", "vote_auth1@example.com")
user2 := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user2", "vote_auth2@example.com")
firstUser := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user1", "vote_auth1@example.com")
secondUser := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user2", "vote_auth2@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user1.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
post := testutils.CreatePostWithRepo(t, postRepo, firstUser.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
voteData := map[string]string{"type": "up"}
voteBody, _ := json.Marshal(voteData)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBody))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user2.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user2.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteResp := httptest.NewRecorder()
voteHandler.CastVote(voteResp, voteReq)
if voteResp.Code != http.StatusOK {
t.Errorf("Users should be able to vote on any post, got %d", voteResp.Code)
voteResponse := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
if voteResponse.Code != http.StatusOK {
t.Errorf("Users should be able to vote on any post, got %d", voteResponse.Code)
}
})
@@ -664,12 +446,8 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to generate expired token: %v", err)
}
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+expiredToken)
meResp := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized)
meResponse := makeRequest(t, ctx.Router, "GET", "/api/auth/me", nil, map[string]string{"Authorization": "Bearer " + expiredToken})
testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
})
t.Run("Authorization_Token_Tampering", func(t *testing.T) {
@@ -678,12 +456,12 @@ func TestIntegration_Handlers(t *testing.T) {
tamperedToken := user.Token[:len(user.Token)-5] + "XXXXX"
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+tamperedToken)
meResp := httptest.NewRecorder()
meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meRequest.Header.Set("Authorization", "Bearer "+tamperedToken)
meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized)
ctx.Router.ServeHTTP(meResponse, meRequest)
testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
})
t.Run("Authorization_Session_Version_Mismatch", func(t *testing.T) {
@@ -711,12 +489,12 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to generate invalid token: %v", err)
}
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+invalidToken)
meResp := httptest.NewRecorder()
meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meRequest.Header.Set("Authorization", "Bearer "+invalidToken)
meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized)
ctx.Router.ServeHTTP(meResponse, meRequest)
testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
})
}
@@ -748,11 +526,9 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
}
voteService := services.NewVoteService(voteRepo, postRepo, db)
apiHandler := handlers.NewAPIHandlerWithMonitoring(testutils.AppTestConfig, postRepo, userRepo, voteService, db, monitor)
t.Run("Health endpoint includes database monitoring", func(t *testing.T) {
user := &database.User{
Username: "monitoring_user",
Email: "monitoring@example.com",
@@ -765,7 +541,6 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
recorder := httptest.NewRecorder()
apiHandler.GetHealth(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var response map[string]any
@@ -777,7 +552,7 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
t.Error("Expected success to be true")
}
data, ok := response["data"].(map[string]any)
data, ok := getDataFromResponse(response)
if !ok {
t.Fatal("Expected data to be a map")
}
@@ -813,7 +588,6 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
recorder := httptest.NewRecorder()
apiHandler.GetMetrics(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var response map[string]any
@@ -825,7 +599,7 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
t.Error("Expected success to be true")
}
data, ok := response["data"].(map[string]any)
data, ok := getDataFromResponse(response)
if !ok {
t.Fatal("Expected data to be a map")
}

View File

@@ -1,6 +1,7 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -282,14 +283,14 @@ func setupPageHandlerTestContext(t *testing.T) *testContext {
func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*http.Cookie) string {
t.Helper()
req := httptest.NewRequest("GET", path, nil)
request := httptest.NewRequest("GET", path, nil)
for _, cookie := range cookies {
req.AddCookie(cookie)
request.AddCookie(cookie)
}
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
cookieList := rec.Result().Cookies()
cookieList := recorder.Result().Cookies()
for _, cookie := range cookieList {
if cookie.Name == "csrf_token" {
return cookie.Value
@@ -299,32 +300,32 @@ func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*ht
return ""
}
func assertJSONResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) map[string]any {
func assertJSONResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) map[string]any {
t.Helper()
if rec.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String())
if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
return nil
}
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v. Body: %s", err, rec.Body.String())
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v. Body: %s", err, recorder.Body.String())
return nil
}
return response
}
func assertErrorResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) {
func assertErrorResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) {
t.Helper()
if rec.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String())
if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
return
}
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode error response: %v. Body: %s", err, rec.Body.String())
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode error response: %v. Body: %s", err, recorder.Body.String())
return
}
@@ -335,23 +336,23 @@ func assertErrorResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedS
}
}
func assertStatus(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) {
func assertStatus(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) {
t.Helper()
if rec.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String())
if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
}
}
func assertStatusRange(t *testing.T, rec *httptest.ResponseRecorder, minStatus, maxStatus int) {
func assertStatusRange(t *testing.T, recorder *httptest.ResponseRecorder, minStatus, maxStatus int) {
t.Helper()
if rec.Code < minStatus || rec.Code > maxStatus {
t.Errorf("Expected status between %d and %d, got %d. Body: %s", minStatus, maxStatus, rec.Code, rec.Body.String())
if recorder.Code < minStatus || recorder.Code > maxStatus {
t.Errorf("Expected status between %d and %d, got %d. Body: %s", minStatus, maxStatus, recorder.Code, recorder.Body.String())
}
}
func assertCookie(t *testing.T, rec *httptest.ResponseRecorder, name, expectedValue string) {
func assertCookie(t *testing.T, recorder *httptest.ResponseRecorder, name, expectedValue string) {
t.Helper()
cookies := rec.Result().Cookies()
cookies := recorder.Result().Cookies()
for _, cookie := range cookies {
if cookie.Name == name {
if expectedValue != "" && cookie.Value != expectedValue {
@@ -363,9 +364,9 @@ func assertCookie(t *testing.T, rec *httptest.ResponseRecorder, name, expectedVa
t.Errorf("Expected cookie %s not found", name)
}
func assertCookieCleared(t *testing.T, rec *httptest.ResponseRecorder, name string) {
func assertCookieCleared(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
t.Helper()
cookies := rec.Result().Cookies()
cookies := recorder.Result().Cookies()
for _, cookie := range cookies {
if cookie.Name == name {
if cookie.Value != "" {
@@ -376,21 +377,17 @@ func assertCookieCleared(t *testing.T, rec *httptest.ResponseRecorder, name stri
}
}
func assertHeader(t *testing.T, rec *httptest.ResponseRecorder, name, expectedValue string) {
func assertHeader(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
t.Helper()
actualValue := rec.Header().Get(name)
if expectedValue == "" {
actualValue := recorder.Header().Get(name)
if actualValue == "" {
t.Errorf("Expected header %s to be present", name)
}
} else if actualValue != expectedValue {
t.Errorf("Expected header %s=%s, got %s", name, expectedValue, actualValue)
}
}
func assertHeaderContains(t *testing.T, rec *httptest.ResponseRecorder, name, substring string) {
func assertHeaderContains(t *testing.T, recorder *httptest.ResponseRecorder, name, substring string) {
t.Helper()
actualValue := rec.Header().Get(name)
actualValue := recorder.Header().Get(name)
if !strings.Contains(actualValue, substring) {
t.Errorf("Expected header %s to contain %s, got %s", name, substring, actualValue)
}
@@ -450,3 +447,83 @@ func createUserWithCleanup(t *testing.T, ctx *testContext, username, email strin
})
return user
}
func makeRequest(t *testing.T, router http.Handler, method, path string, body []byte, headers map[string]string) *httptest.ResponseRecorder {
t.Helper()
var requestBody *bytes.Buffer
if body != nil {
requestBody = bytes.NewBuffer(body)
} else {
requestBody = bytes.NewBuffer(nil)
}
request := httptest.NewRequest(method, path, requestBody)
for key, value := range headers {
request.Header.Set(key, value)
}
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
return recorder
}
func makeAuthenticatedRequest(t *testing.T, router http.Handler, method, path string, body []byte, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
var requestBody *bytes.Buffer
if body != nil {
requestBody = bytes.NewBuffer(body)
} else {
requestBody = bytes.NewBuffer(nil)
}
request := httptest.NewRequest(method, path, requestBody)
request.Header.Set("Authorization", "Bearer "+user.Token)
if body != nil {
request.Header.Set("Content-Type", "application/json")
}
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
if urlParams != nil {
request = testutils.WithURLParams(request, urlParams)
}
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
return recorder
}
func makeGetRequest(t *testing.T, router http.Handler, path string) *httptest.ResponseRecorder {
t.Helper()
return makeRequest(t, router, "GET", path, nil, nil)
}
func makeAuthenticatedGetRequest(t *testing.T, router http.Handler, path string, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
return makeAuthenticatedRequest(t, router, "GET", path, nil, user, urlParams)
}
func makePostRequest(t *testing.T, router http.Handler, path string, body map[string]any, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeAuthenticatedRequest(t, router, "POST", path, bodyBytes, user, urlParams)
}
func makePutRequest(t *testing.T, router http.Handler, path string, body map[string]any, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeAuthenticatedRequest(t, router, "PUT", path, bodyBytes, user, urlParams)
}
func makeDeleteRequest(t *testing.T, router http.Handler, path string, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
return makeAuthenticatedRequest(t, router, "DELETE", path, nil, user, urlParams)
}
func makePostRequestWithJSON(t *testing.T, router http.Handler, path string, body map[string]any) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeRequest(t, router, "POST", path, bodyBytes, map[string]string{"Content-Type": "application/json"})
}
func getDataFromResponse(response map[string]any) (map[string]any, bool) {
if response == nil {
return nil, false
}
data, ok := response["data"].(map[string]any)
return data, ok
}

View File

@@ -22,26 +22,26 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, authService, ctx.Suite.UserRepo, "settings_email_user", "settings_email@example.com")
getReq := httptest.NewRequest("GET", "/settings", nil)
getReq.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", "/settings", nil)
getRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
getRecorder := httptest.NewRecorder()
router.ServeHTTP(getRecorder, getRequest)
csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("email", "newemail@example.com")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("email", "newemail@example.com")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/email", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/settings/email", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Settings_Username_Update_Form", func(t *testing.T) {
@@ -51,19 +51,19 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("username", "new_username")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("username", "new_username")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/username", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/settings/username", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Settings_Password_Update_Form", func(t *testing.T) {
@@ -74,20 +74,20 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("current_password", "SecurePass123!")
reqBody.Set("new_password", "NewSecurePass123!")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("current_password", "SecurePass123!")
requestBody.Set("new_password", "NewSecurePass123!")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/password", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/settings/password", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Logout_Page_Handler", func(t *testing.T) {
@@ -98,19 +98,19 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/logout", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/logout", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusSeeOther)
assertCookieCleared(t, rec, "auth_token")
assertStatus(t, recorder, http.StatusSeeOther)
assertCookieCleared(t, recorder, "auth_token")
})
t.Run("Resend_Verification_Page_Handler", func(t *testing.T) {
@@ -120,18 +120,18 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/resend-verification")
reqBody := url.Values{}
reqBody.Set("email", "resend_page@example.com")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("email", "resend_page@example.com")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/resend-verification", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/resend-verification", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Post_Vote_Page_Handler", func(t *testing.T) {
@@ -142,26 +142,26 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
post := testutils.CreatePostWithRepo(t, freshCtx.Suite.PostRepo, user.User.ID, "Vote Page Test", "https://example.com/vote-page")
getReq := httptest.NewRequest("GET", fmt.Sprintf("/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", fmt.Sprintf("/posts/%d", post.ID), nil)
getRecorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(getRecorder, getRequest)
csrfToken := getCSRFToken(t, freshCtx.Router, fmt.Sprintf("/posts/%d", post.ID))
reqBody := url.Values{}
reqBody.Set("action", "up")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("action", "up")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", fmt.Sprintf("/posts/%d/vote", post.ID), strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", fmt.Sprintf("/posts/%d/vote", post.ID), strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Login_Page_Handler_Workflow", func(t *testing.T) {
@@ -172,20 +172,20 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/login")
reqBody := url.Values{}
reqBody.Set("username", "login_page_user")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("username", "login_page_user")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/login", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/login", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusSeeOther)
assertCookie(t, rec, "auth_token", "")
assertStatus(t, recorder, http.StatusSeeOther)
assertCookie(t, recorder, "auth_token", "")
})
t.Run("Email_Confirmation_Page_Handler", func(t *testing.T) {
@@ -198,11 +198,11 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
token = "test-token"
}
req := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(token), nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(token), nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
}

View File

@@ -16,63 +16,62 @@ func TestIntegration_PageHandler(t *testing.T) {
router := ctx.Router
t.Run("Home_Page_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
if !strings.Contains(rec.Body.String(), "<html") {
if !strings.Contains(recorder.Body.String(), "<html") {
t.Error("Expected HTML content")
}
})
t.Run("Login_Form_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/login", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/login", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
body := rec.Body.String()
body := recorder.Body.String()
if !strings.Contains(body, "login") && !strings.Contains(body, "Login") {
t.Error("Expected login form content")
}
})
t.Run("Register_Form_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/register", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, recorder, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
body := rec.Body.String()
body := recorder.Body.String()
if !strings.Contains(body, "register") && !strings.Contains(body, "Register") {
t.Error("Expected register form content")
}
})
t.Run("PageHandler_With_CSRF_Token", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/register", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertCookie(t, rec, "csrf_token", "")
assertCookie(t, recorder, "csrf_token", "")
})
t.Run("PageHandler_Form_Submission", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", "/register", nil)
getRecorder := httptest.NewRecorder()
router.ServeHTTP(getRecorder, getRequest)
cookies := getRec.Result().Cookies()
cookies := getRecorder.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
@@ -85,33 +84,33 @@ func TestIntegration_PageHandler(t *testing.T) {
t.Fatal("Expected CSRF cookie")
}
reqBody := url.Values{}
reqBody.Set("username", "page_form_user")
reqBody.Set("email", "page_form@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfCookie.Value)
requestBody := url.Values{}
requestBody.Set("username", "page_form_user")
requestBody.Set("email", "page_form@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", csrfCookie.Value)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("PageHandler_Authenticated_Access", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "page_auth_user", "page_auth@example.com")
req := httptest.NewRequest("GET", "/settings", nil)
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/settings", nil)
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
})
t.Run("PageHandler_Post_Display", func(t *testing.T) {
@@ -120,34 +119,34 @@ func TestIntegration_PageHandler(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Page Test Post", "https://example.com/page-test")
req := httptest.NewRequest("GET", "/posts/"+fmt.Sprintf("%d", post.ID), nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/posts/"+fmt.Sprintf("%d", post.ID), nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
body := rec.Body.String()
body := recorder.Body.String()
if !strings.Contains(body, "Page Test Post") {
t.Error("Expected post title in page")
}
})
t.Run("PageHandler_Search_Page", func(t *testing.T) {
req := httptest.NewRequest("GET", "/search?q=test", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/search?q=test", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
})
t.Run("PageHandler_Error_Handling", func(t *testing.T) {
req := httptest.NewRequest("GET", "/nonexistent", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/nonexistent", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusNotFound)
assertStatus(t, recorder, http.StatusNotFound)
})
}

View File

@@ -1,8 +1,6 @@
package integration
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
@@ -32,17 +30,12 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to create user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"username_or_email": "reset_user",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/forgot-password", reqBody)
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success=true")
@@ -77,18 +70,13 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatal("Expected password reset token")
}
reqBody := map[string]string{
reqBody := map[string]any{
"token": resetToken,
"new_password": "NewPassword123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
loginResult, err := ctx.AuthService.Login("reset_complete_user", "NewPassword123!")
if err != nil {
@@ -120,14 +108,14 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
reqBody := url.Values{}
reqBody.Set("username_or_email", "page_reset_user")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/forgot-password", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/forgot-password", strings.NewReader(reqBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
pageRouter.ServeHTTP(rec, req)
pageRouter.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
resetToken := pageCtx.Suite.EmailSender.PasswordResetToken()
if resetToken == "" {
@@ -166,33 +154,23 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to update user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"token": resetToken,
"new_password": "NewPassword123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("PasswordReset_InvalidToken", func(t *testing.T) {
reqBody := map[string]string{
reqBody := map[string]any{
"token": "invalid-token",
"new_password": "NewPassword123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("PasswordReset_WeakPassword", func(t *testing.T) {
@@ -214,18 +192,13 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
resetToken := ctx.Suite.EmailSender.PasswordResetToken()
reqBody := map[string]string{
reqBody := map[string]any{
"token": resetToken,
"new_password": "123",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("PasswordReset_EmailIntegration", func(t *testing.T) {
@@ -243,17 +216,12 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to create user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"username_or_email": "email_reset@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, freshCtx.Router, "/api/auth/forgot-password", reqBody)
freshCtx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
resetToken := freshCtx.Suite.EmailSender.PasswordResetToken()
if resetToken == "" {

View File

@@ -51,24 +51,24 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
assertHeader(t, rec, "Retry-After", "")
assertHeader(t, recorder, "Retry-After")
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil {
if err := json.NewDecoder(recorder.Body).Decode(&response); err == nil {
if _, exists := response["retry_after"]; !exists {
t.Error("Expected retry_after in response")
}
@@ -81,17 +81,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
t.Run("Health_RateLimit_Enforced", func(t *testing.T) {
@@ -100,17 +100,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 3; i++ {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/health", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/health", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
t.Run("Metrics_RateLimit_Enforced", func(t *testing.T) {
@@ -119,17 +119,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ {
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
t.Run("RateLimit_Different_Endpoints_Independent", func(t *testing.T) {
@@ -139,17 +139,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
})
t.Run("RateLimit_With_Authentication", func(t *testing.T) {
@@ -166,20 +166,20 @@ func TestIntegration_RateLimiting(t *testing.T) {
user := createAuthenticatedUser(t, authService, suite.UserRepo, uniqueTestUsername(t, "ratelimit_auth"), uniqueTestEmail(t, "ratelimit_auth"))
for i := 0; i < 3; i++ {
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/auth/me", nil)
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/auth/me", nil)
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
}

View File

@@ -17,35 +17,29 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
router := ctx.Router
t.Run("SecurityHeaders_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/health")
router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
assertHeader(t, rec, "X-Content-Type-Options", "")
assertHeader(t, rec, "X-Frame-Options", "")
assertHeader(t, rec, "X-XSS-Protection", "")
assertHeader(t, request, "X-Content-Type-Options")
assertHeader(t, request, "X-Frame-Options")
assertHeader(t, request, "X-XSS-Protection")
})
t.Run("CORS_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("OPTIONS", "/api/posts", nil)
req.Header.Set("Origin", "http://localhost:3000")
rec := httptest.NewRecorder()
request := httptest.NewRequest("OPTIONS", "/api/posts", nil)
request.Header.Set("Origin", "http://localhost:3000")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertHeader(t, rec, "Access-Control-Allow-Origin", "")
assertHeader(t, recorder, "Access-Control-Allow-Origin")
})
t.Run("Logging_Middleware_Executes", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/health")
router.ServeHTTP(rec, req)
if rec.Code == 0 {
if request.Code == 0 {
t.Error("Expected logging middleware to execute")
}
})
@@ -53,27 +47,24 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
t.Run("RequestSizeLimit_Enforced", func(t *testing.T) {
user := createUserWithCleanup(t, ctx, "size_limit_user", "size_limit@example.com")
largeBody := strings.Repeat("a", 10*1024*1024)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(largeBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(largeBody))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusRequestEntityTooLarge && rec.Code != http.StatusBadRequest {
t.Errorf("Expected status 413 or 400 for oversized request, got %d. Body: %s", rec.Code, rec.Body.String())
if recorder.Code != http.StatusRequestEntityTooLarge && recorder.Code != http.StatusBadRequest {
t.Errorf("Expected status 413 or 400 for oversized request, got %d. Body: %s", recorder.Code, recorder.Body.String())
}
})
t.Run("DBMonitoring_Active", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := makeGetRequest(t, router, "/health")
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil {
if err := json.NewDecoder(request.Body).Decode(&response); err == nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database_stats"]; !exists {
t.Error("Expected database_stats in health response")
@@ -83,12 +74,9 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
})
t.Run("Metrics_Middleware_Executes", func(t *testing.T) {
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/metrics")
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database"]; !exists {
@@ -99,34 +87,25 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
})
t.Run("StaticFiles_Served", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/robots.txt")
router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
if !strings.Contains(rec.Body.String(), "User-agent") {
if !strings.Contains(request.Body.String(), "User-agent") {
t.Error("Expected robots.txt content")
}
})
t.Run("API_Routes_Accessible", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/api/posts")
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("Health_Endpoint_Accessible", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/health")
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success=true in health response")
@@ -135,40 +114,33 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
})
t.Run("Middleware_Order_Correct", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/api/posts")
router.ServeHTTP(rec, req)
assertHeader(t, request, "X-Content-Type-Options")
assertHeader(t, rec, "X-Content-Type-Options", "")
if rec.Code == 0 {
if request.Code == 0 {
t.Error("Response should have status code")
}
})
t.Run("Compression_Middleware_Active", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip")
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
request.Header.Set("Accept-Encoding", "gzip")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Header().Get("Content-Encoding") == "" {
if recorder.Header().Get("Content-Encoding") == "" {
t.Log("Compression may not be applied to small responses")
}
})
t.Run("Cache_Middleware_Active", func(t *testing.T) {
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := makeGetRequest(t, router, "/api/posts")
req2 := httptest.NewRequest("GET", "/api/posts", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := makeGetRequest(t, router, "/api/posts")
if rec1.Code != rec2.Code {
if firstRequest.Code != secondRequest.Code {
t.Error("Cached responses should have same status")
}
})
@@ -177,35 +149,23 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "auth_middleware_user", "auth_middleware@example.com")
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, router, "/api/auth/me", user, nil)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("RateLimit_Middleware_Integration", func(t *testing.T) {
rateLimitCtx := setupTestContext(t)
rateLimitRouter := rateLimitCtx.Router
for i := 0; i < 3; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
rateLimitRouter.ServeHTTP(rec, req)
for range 3 {
request := makePostRequestWithJSON(t, rateLimitRouter, "/api/auth/login", map[string]any{"username": "test", "password": "test"})
_ = request
}
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, rateLimitRouter, "/api/auth/login", map[string]any{"username": "test", "password": "test"})
rateLimitRouter.ServeHTTP(rec, req)
if rec.Code == http.StatusTooManyRequests {
if request.Code == http.StatusTooManyRequests {
t.Log("Rate limiting is working")
}
})

View File

@@ -21,60 +21,60 @@ func TestIntegration_SessionManagement(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_pass_user", "session_pass@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user.User.ID)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
firstRequest.Header.Set("Authorization", "Bearer "+user.Token)
firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, user.User.ID)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
assertStatus(t, rec1, http.StatusOK)
assertStatus(t, firstRecorder, http.StatusOK)
reqBody := map[string]string{
requestBody := map[string]string{
"current_password": "SecurePass123!",
"new_password": "NewSecurePass123!",
}
body, _ := json.Marshal(reqBody)
req2 := httptest.NewRequest("PUT", "/api/auth/password", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
body, _ := json.Marshal(requestBody)
secondRequest := httptest.NewRequest("PUT", "/api/auth/password", bytes.NewBuffer(body))
secondRequest.Header.Set("Content-Type", "application/json")
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec2, http.StatusOK)
assertStatus(t, secondRecorder, http.StatusOK)
req3 := httptest.NewRequest("GET", "/api/auth/me", nil)
req3.Header.Set("Authorization", "Bearer "+user.Token)
req3 = testutils.WithUserContext(req3, middleware.UserIDKey, user.User.ID)
rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3)
thirdRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
thirdRequest.Header.Set("Authorization", "Bearer "+user.Token)
thirdRequest = testutils.WithUserContext(thirdRequest, middleware.UserIDKey, user.User.ID)
thirdRecorder := httptest.NewRecorder()
router.ServeHTTP(thirdRecorder, thirdRequest)
assertErrorResponse(t, rec3, http.StatusUnauthorized)
assertErrorResponse(t, thirdRecorder, http.StatusUnauthorized)
})
t.Run("Session_Invalidation_On_Account_Lock", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_lock_user", "session_lock@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user.User.ID)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
firstRequest.Header.Set("Authorization", "Bearer "+user.Token)
firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, user.User.ID)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
assertStatus(t, rec1, http.StatusOK)
assertStatus(t, firstRecorder, http.StatusOK)
if err := ctx.Suite.UserRepo.Lock(user.User.ID); err != nil {
t.Fatalf("Failed to lock user: %v", err)
}
req2 := httptest.NewRequest("GET", "/api/auth/me", nil)
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertErrorResponse(t, rec2, http.StatusUnauthorized)
assertErrorResponse(t, secondRecorder, http.StatusUnauthorized)
})
t.Run("Refresh_Token_Revocation", func(t *testing.T) {
@@ -90,48 +90,48 @@ func TestIntegration_SessionManagement(t *testing.T) {
t.Fatal("Expected refresh token")
}
reqBody := map[string]string{
requestBody := map[string]string{
"refresh_token": loginResult.RefreshToken,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil {
t.Fatalf("Failed to revoke token: %v", err)
}
req2 := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json")
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
secondRequest.Header.Set("Content-Type", "application/json")
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertErrorResponse(t, rec2, http.StatusUnauthorized)
assertErrorResponse(t, secondRecorder, http.StatusUnauthorized)
})
t.Run("Multiple_Sessions_Independent", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user1", "multi_session1@example.com")
user2 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user2", "multi_session2@example.com")
firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user1", "multi_session1@example.com")
secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user2", "multi_session2@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user1.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user1.User.ID)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
firstRequest.Header.Set("Authorization", "Bearer "+firstUser.Token)
firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, firstUser.User.ID)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
req2 := httptest.NewRequest("GET", "/api/auth/me", nil)
req2.Header.Set("Authorization", "Bearer "+user2.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user2.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
secondRequest.Header.Set("Authorization", "Bearer "+secondUser.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, secondUser.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec1, http.StatusOK)
assertStatus(t, rec2, http.StatusOK)
assertStatus(t, firstRecorder, http.StatusOK)
assertStatus(t, secondRecorder, http.StatusOK)
})
}
@@ -144,17 +144,17 @@ func TestIntegration_AccountDeletion(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_flow_user", "del_flow@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Test Post", "https://example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
requestBody := map[string]string{}
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, recorder, http.StatusOK)
if response == nil {
return
}
@@ -171,13 +171,13 @@ func TestIntegration_AccountDeletion(t *testing.T) {
"token": deletionToken,
}
confirmBodyBytes, _ := json.Marshal(confirmBody)
confirmReq := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmReq.Header.Set("Content-Type", "application/json")
confirmRec := httptest.NewRecorder()
confirmRequest := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmRequest.Header.Set("Content-Type", "application/json")
confirmRecorder := httptest.NewRecorder()
router.ServeHTTP(confirmRec, confirmReq)
router.ServeHTTP(confirmRecorder, confirmRequest)
confirmResponse := assertJSONResponse(t, confirmRec, http.StatusOK)
confirmResponse := assertJSONResponse(t, confirmRecorder, http.StatusOK)
if confirmResponse == nil {
return
}
@@ -209,17 +209,17 @@ func TestIntegration_AccountDeletion(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_posts_user", "del_posts@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Deletion Post", "https://example.com/deletion")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
requestBody := map[string]string{}
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, recorder, http.StatusOK)
if response == nil {
return
}
@@ -237,13 +237,13 @@ func TestIntegration_AccountDeletion(t *testing.T) {
"delete_posts": true,
}
confirmBodyBytes, _ := json.Marshal(confirmBody)
confirmReq := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmReq.Header.Set("Content-Type", "application/json")
confirmRec := httptest.NewRecorder()
confirmRequest := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmRequest.Header.Set("Content-Type", "application/json")
confirmRecorder := httptest.NewRecorder()
router.ServeHTTP(confirmRec, confirmReq)
router.ServeHTTP(confirmRecorder, confirmRequest)
confirmResponse := assertJSONResponse(t, confirmRec, http.StatusOK)
confirmResponse := assertJSONResponse(t, confirmRecorder, http.StatusOK)
if confirmResponse == nil {
return
}
@@ -275,12 +275,12 @@ func TestIntegration_MetricsCollection(t *testing.T) {
router := ctx.Router
t.Run("Metrics_Endpoint_Returns_Data", func(t *testing.T) {
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, recorder, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database"]; !exists {
@@ -294,13 +294,13 @@ func TestIntegration_MetricsCollection(t *testing.T) {
ctx.Suite.EmailSender.Reset()
createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "metrics_user", "metrics@example.com")
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil {
if err := json.NewDecoder(recorder.Body).Decode(&response); err == nil {
if data, ok := response["data"].(map[string]any); ok {
if dbData, exists := data["database"].(map[string]any); exists {
if _, hasQueries := dbData["total_queries"]; !hasQueries {
@@ -323,7 +323,7 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, 10)
for i := 0; i < 10; i++ {
for idx := range 10 {
wg.Add(1)
go func(index int) {
defer wg.Done()
@@ -334,18 +334,18 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
"content": "Concurrent test content",
}
body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusCreated {
errors <- fmt.Errorf("Post %d failed with status %d", index, rec.Code)
if recorder.Code != http.StatusCreated {
errors <- fmt.Errorf("Post %d failed with status %d", index, recorder.Code)
}
}(i)
}(idx)
}
wg.Wait()
@@ -370,28 +370,26 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, 5)
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range 5 {
wg.Go(func() {
voteBody := map[string]string{
"type": "up",
}
body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK {
errors <- fmt.Errorf("Vote failed with status %d", rec.Code)
if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("Vote failed with status %d", recorder.Code)
}
}()
})
}
wg.Wait()
@@ -411,20 +409,18 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, 20)
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range 20 {
wg.Go(func() {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK {
errors <- fmt.Errorf("Read failed with status %d", rec.Code)
if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("Read failed with status %d", recorder.Code)
}
}()
})
}
wg.Wait()

View File

@@ -38,6 +38,10 @@ func CompressionMiddlewareWithConfig(config *CompressionConfig) func(http.Handle
next.ServeHTTP(bufferedWriter, r)
if bufferedWriter.isRedirect {
return
}
if buf.Len() < config.MinSize {
bufferedWriter.flush()
w.Write(buf.Bytes())
@@ -73,9 +77,13 @@ type bufferedResponseWriter struct {
buffer *bytes.Buffer
statusCode int
headerWritten bool
isRedirect bool
}
func (brw *bufferedResponseWriter) Write(b []byte) (int, error) {
if brw.isRedirect {
return brw.ResponseWriter.Write(b)
}
if !brw.headerWritten {
brw.statusCode = http.StatusOK
}
@@ -87,6 +95,11 @@ func (brw *bufferedResponseWriter) WriteHeader(code int) {
return
}
brw.statusCode = code
if isRedirect(code) {
brw.isRedirect = true
brw.ResponseWriter.WriteHeader(code)
brw.headerWritten = true
}
}
func (brw *bufferedResponseWriter) Header() http.Header {
@@ -100,6 +113,10 @@ func (brw *bufferedResponseWriter) flush() {
}
}
func isRedirect(statusCode int) bool {
return statusCode >= 300 && statusCode < 400
}
func shouldCompress(r *http.Request, config *CompressionConfig) bool {
return r.Header.Get("Content-Encoding") == ""
}

View File

@@ -27,7 +27,14 @@ func ValidationMiddleware() func(http.Handler) http.Handler {
dto := reflect.New(dtoType).Interface()
if err := json.NewDecoder(r.Body).Decode(dto); err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
response := map[string]any{
"success": false,
"error": "Invalid JSON",
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(response)
return
}
@@ -77,3 +84,7 @@ func GetDTOTypeFromContext(ctx context.Context) reflect.Type {
func GetValidatedDTOFromContext(ctx context.Context) any {
return ctx.Value(validatedDTOKey)
}
func SetValidatedDTOInContext(ctx context.Context, dto any) context.Context {
return context.WithValue(ctx, validatedDTOKey, dto)
}

View File

@@ -28,6 +28,7 @@ type UserRepository interface {
Unlock(id uint) error
GetPosts(userID uint, limit, offset int) ([]database.Post, error)
GetDeletedUsers() ([]database.User, error)
GetByUsernamePrefix(prefix string) (*database.User, error)
HardDeleteAll() (int64, error)
Count() (int64, error)
WithTx(tx *gorm.DB) UserRepository
@@ -240,6 +241,17 @@ func (r *userRepository) GetDeletedUsers() ([]database.User, error) {
return users, err
}
func (r *userRepository) GetByUsernamePrefix(prefix string) (*database.User, error) {
var user database.User
err := r.db.
Where("username LIKE ? AND email LIKE ?", prefix+"%", prefix+"%@goyco.local").
First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
func (r *userRepository) HardDeleteAll() (int64, error) {
var totalDeleted int64
err := r.db.Transaction(func(tx *gorm.DB) error {

View File

@@ -20,6 +20,7 @@ type VoteRepository interface {
Count() (int64, error)
CountByPostID(postID uint) (int64, error)
CountByUserID(userID uint) (int64, error)
GetVoteCountsByPostID(postID uint) (upVotes int, downVotes int, err error)
WithTx(tx *gorm.DB) VoteRepository
}
@@ -144,3 +145,20 @@ func (r *voteRepository) Count() (int64, error) {
err := r.db.Model(&database.Vote{}).Count(&count).Error
return count, err
}
func (r *voteRepository) GetVoteCountsByPostID(postID uint) (int, int, error) {
var result struct {
UpVotes int64
DownVotes int64
}
err := r.db.Model(&database.Vote{}).
Select("COUNT(CASE WHEN type = ? THEN 1 END) as up_votes, COUNT(CASE WHEN type = ? THEN 1 END) as down_votes", database.VoteUp, database.VoteDown).
Where("post_id = ?", postID).
Scan(&result).Error
if err != nil {
return 0, 0, err
}
return int(result.UpVotes), int(result.DownVotes), nil
}

View File

@@ -1,8 +1,10 @@
package server
import (
"mime"
"net/http"
"path/filepath"
"strings"
"time"
"goyco/internal/config"
@@ -73,6 +75,7 @@ func NewRouter(cfg RouterConfig) http.Handler {
},
CSRFMiddleware: middleware.CSRFMiddleware(),
AuthMiddleware: middleware.NewAuth(cfg.AuthService),
ValidationMiddleware: middleware.ValidationMiddleware(),
}
if cfg.PageHandler != nil {
@@ -123,7 +126,33 @@ func NewRouter(cfg RouterConfig) http.Handler {
staticDir = "./internal/static/"
}
router.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir))))
staticFileServer := http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir)))
router.Handle("/static/*", staticFileHandler(staticFileServer))
return router
}
func staticFileHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
ext := filepath.Ext(path)
if ext == ".css" {
w.Header().Set("Content-Type", "text/css; charset=utf-8")
} else if ext == ".js" {
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
} else if ext == ".json" {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
} else if ext == ".ico" {
w.Header().Set("Content-Type", "image/x-icon")
} else if strings.HasPrefix(mime.TypeByExtension(ext), "image/") {
w.Header().Set("Content-Type", mime.TypeByExtension(ext))
} else if strings.HasPrefix(mime.TypeByExtension(ext), "font/") {
w.Header().Set("Content-Type", mime.TypeByExtension(ext))
} else if mimeType := mime.TypeByExtension(ext); mimeType != "" {
w.Header().Set("Content-Type", mimeType)
}
next.ServeHTTP(w, r)
})
}

View File

@@ -3,6 +3,7 @@ package server
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"goyco/internal/config"
@@ -105,9 +106,9 @@ func defaultRateLimitConfig() config.RateLimitConfig {
return testutils.AppTestConfig.RateLimit
}
func TestAPIRootRouting(t *testing.T) {
func createDefaultRouterConfig() RouterConfig {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
return RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
@@ -115,7 +116,15 @@ func TestAPIRootRouting(t *testing.T) {
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
}
}
func createTestRouter(cfg RouterConfig) http.Handler {
return NewRouter(cfg)
}
func TestAPIRootRouting(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
@@ -141,23 +150,23 @@ func TestAPIRootRouting(t *testing.T) {
}
func TestProtectedRoutesRequireAuth(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
protectedRoutes := []struct {
method string
path string
}{
{http.MethodGet, "/api/auth/me"},
{http.MethodPost, "/api/auth/logout"},
{http.MethodPost, "/api/auth/revoke"},
{http.MethodPost, "/api/auth/revoke-all"},
{http.MethodPut, "/api/auth/email"},
{http.MethodPut, "/api/auth/username"},
{http.MethodPut, "/api/auth/password"},
{http.MethodDelete, "/api/auth/account"},
{http.MethodPost, "/api/posts"},
{http.MethodPut, "/api/posts/1"},
{http.MethodDelete, "/api/posts/1"},
{http.MethodPost, "/api/posts/1/vote"},
{http.MethodDelete, "/api/posts/1/vote"},
{http.MethodGet, "/api/posts/1/vote"},
@@ -183,17 +192,9 @@ func TestProtectedRoutesRequireAuth(t *testing.T) {
}
func TestRouterWithDebugMode(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
Debug: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.Debug = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -206,16 +207,9 @@ func TestRouterWithDebugMode(t *testing.T) {
}
func TestRouterWithCacheDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
DisableCache: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
})
cfg := createDefaultRouterConfig()
cfg.DisableCache = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -228,17 +222,9 @@ func TestRouterWithCacheDisabled(t *testing.T) {
}
func TestRouterWithCompressionDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
DisableCompression: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.DisableCompression = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -251,19 +237,9 @@ func TestRouterWithCompressionDisabled(t *testing.T) {
}
func TestRouterWithCustomDBMonitor(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
customDBMonitor := middleware.NewInMemoryDBMonitor()
router := NewRouter(RouterConfig{
DBMonitor: customDBMonitor,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.DBMonitor = middleware.NewInMemoryDBMonitor()
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -296,18 +272,9 @@ func TestRouterWithPageHandler(t *testing.T) {
}
func TestRouterWithStaticDir(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
StaticDir: "/custom/static/path",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.StaticDir = "/custom/static/path"
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -320,18 +287,9 @@ func TestRouterWithStaticDir(t *testing.T) {
}
func TestRouterWithEmptyStaticDir(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
StaticDir: "",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.StaticDir = ""
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -344,20 +302,11 @@ func TestRouterWithEmptyStaticDir(t *testing.T) {
}
func TestRouterWithAllFeaturesDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
Debug: true,
DisableCache: true,
DisableCompression: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.Debug = true
cfg.DisableCache = true
cfg.DisableCompression = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -370,15 +319,9 @@ func TestRouterWithAllFeaturesDisabled(t *testing.T) {
}
func TestRouterWithoutAPIHandler(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, _, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.APIHandler = nil
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -391,17 +334,7 @@ func TestRouterWithoutAPIHandler(t *testing.T) {
}
func TestRouterWithoutPageHandler(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
request := httptest.NewRequest(http.MethodGet, "/", nil)
recorder := httptest.NewRecorder()
@@ -414,17 +347,7 @@ func TestRouterWithoutPageHandler(t *testing.T) {
}
func TestSwaggerRoute(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
recorder := httptest.NewRecorder()
@@ -437,18 +360,9 @@ func TestSwaggerRoute(t *testing.T) {
}
func TestStaticFileRoute(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
StaticDir: "../../internal/static/",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.StaticDir = "../../internal/static/"
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/static/css/main.css", nil)
recorder := httptest.NewRecorder()
@@ -461,17 +375,7 @@ func TestStaticFileRoute(t *testing.T) {
}
func TestRouterConfiguration(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
if router == nil {
t.Error("Router should not be nil")
@@ -487,29 +391,484 @@ func TestRouterConfiguration(t *testing.T) {
}
}
func TestRouterMiddlewareIntegration(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
func TestAllRoutesExist(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
if router == nil {
t.Error("Router should not be nil")
publicRoutes := []struct {
method string
path string
description string
}{
{http.MethodGet, "/api", "API info"},
{http.MethodGet, "/health", "Health check"},
{http.MethodGet, "/metrics", "Metrics"},
{http.MethodGet, "/robots.txt", "Robots.txt"},
{http.MethodGet, "/api/posts", "Get posts"},
{http.MethodGet, "/api/posts/search", "Search posts"},
{http.MethodGet, "/api/posts/title", "Fetch title from URL"},
{http.MethodGet, "/api/posts/1", "Get post by ID"},
{http.MethodPost, "/api/auth/register", "Register"},
{http.MethodPost, "/api/auth/login", "Login"},
{http.MethodPost, "/api/auth/refresh", "Refresh token"},
{http.MethodGet, "/api/auth/confirm", "Confirm email"},
{http.MethodPost, "/api/auth/resend-verification", "Resend verification"},
{http.MethodPost, "/api/auth/forgot-password", "Forgot password"},
{http.MethodPost, "/api/auth/reset-password", "Reset password"},
{http.MethodPost, "/api/auth/account/confirm", "Confirm account deletion"},
}
request := httptest.NewRequest(http.MethodGet, "/api", nil)
protectedRoutes := []struct {
method string
path string
description string
}{
{http.MethodGet, "/api/auth/me", "Get current user"},
{http.MethodPost, "/api/auth/logout", "Logout"},
{http.MethodPost, "/api/auth/revoke", "Revoke token"},
{http.MethodPost, "/api/auth/revoke-all", "Revoke all tokens"},
{http.MethodPut, "/api/auth/email", "Update email"},
{http.MethodPut, "/api/auth/username", "Update username"},
{http.MethodPut, "/api/auth/password", "Update password"},
{http.MethodDelete, "/api/auth/account", "Delete account"},
{http.MethodPost, "/api/posts", "Create post"},
{http.MethodPut, "/api/posts/1", "Update post"},
{http.MethodDelete, "/api/posts/1", "Delete post"},
{http.MethodPost, "/api/posts/1/vote", "Cast vote"},
{http.MethodDelete, "/api/posts/1/vote", "Remove vote"},
{http.MethodGet, "/api/posts/1/vote", "Get user vote"},
{http.MethodGet, "/api/posts/1/votes", "Get post votes"},
{http.MethodGet, "/api/users", "Get users"},
{http.MethodPost, "/api/users", "Create user"},
{http.MethodGet, "/api/users/1", "Get user by ID"},
{http.MethodGet, "/api/users/1/posts", "Get user posts"},
}
for _, route := range publicRoutes {
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
invalidMethod := http.MethodPatch
switch route.method {
case http.MethodGet:
invalidMethod = http.MethodDelete
case http.MethodPost:
invalidMethod = http.MethodGet
}
request := httptest.NewRequest(invalidMethod, route.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == 0 {
t.Error("Router should return a status code")
routeExists := recorder.Code == http.StatusMethodNotAllowed
if !routeExists {
request = httptest.NewRequest(route.method, route.path, nil)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == http.StatusNotFound && route.path != "/api/posts/1" && route.path != "/robots.txt" {
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
}
}
})
}
for _, route := range protectedRoutes {
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
request := httptest.NewRequest(route.method, route.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
}
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Protected route %s %s should return 401 without auth, got %d", route.method, route.path, recorder.Code)
}
})
}
}
func TestRouteParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
pathPattern string
testIDs []string
isProtected bool
}{
{
name: "Get post by ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: false,
},
{
name: "Update post by ID",
method: http.MethodPut,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Delete post by ID",
method: http.MethodDelete,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Get user by ID",
method: http.MethodGet,
pathPattern: "/api/users/{id}",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Get user posts by user ID",
method: http.MethodGet,
pathPattern: "/api/users/{id}/posts",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Cast vote for post ID",
method: http.MethodPost,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Remove vote for post ID",
method: http.MethodDelete,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Get user vote for post ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Get post votes by post ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}/votes",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.testIDs {
path := replaceID(tc.pathPattern, id)
t.Run("ID_"+id, func(t *testing.T) {
request := httptest.NewRequest(http.MethodPatch, path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
routeExists := recorder.Code == http.StatusMethodNotAllowed
request = httptest.NewRequest(tc.method, path, nil)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if !routeExists {
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist with ID %s, got 404", tc.method, path, id)
return
}
}
if tc.isProtected {
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Protected route %s %s should return 401 without auth, got %d", tc.method, path, recorder.Code)
}
}
})
}
})
}
}
func replaceID(pattern, id string) string {
return strings.Replace(pattern, "{id}", id, 1)
}
func TestInvalidRouteParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
expectedMin int
expectedMax int
isProtected bool
allow401 bool
}{
{
name: "Non-numeric post ID",
method: http.MethodGet,
path: "/api/posts/abc",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Negative post ID",
method: http.MethodGet,
path: "/api/posts/-1",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Zero post ID",
method: http.MethodGet,
path: "/api/posts/0",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusNotFound,
isProtected: false,
},
{
name: "Post ID with special characters",
method: http.MethodGet,
path: "/api/posts/123@456",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Post ID with encoded spaces",
method: http.MethodGet,
path: "/api/posts/12%2034",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Non-numeric user ID",
method: http.MethodGet,
path: "/api/users/xyz",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Negative user ID",
method: http.MethodGet,
path: "/api/users/-5",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Non-numeric post ID in vote route",
method: http.MethodGet,
path: "/api/posts/invalid/vote",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Very large post ID",
method: http.MethodGet,
path: "/api/posts/999999999999",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusNotFound,
isProtected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest(tc.method, tc.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if tc.isProtected && tc.allow401 {
if recorder.Code != http.StatusUnauthorized && (recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax) {
t.Errorf("Protected route %s %s with invalid parameter should return 401 or status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
}
} else {
if recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax {
t.Errorf("Route %s %s should return status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
}
if recorder.Code != http.StatusNotFound && recorder.Code < 400 {
t.Errorf("Route %s %s with invalid parameter should return error status (4xx), got %d", tc.method, tc.path, recorder.Code)
}
}
})
}
}
func TestQueryParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
queryParams string
expectRoute bool
}{
{
name: "Get posts with limit and offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=10&offset=5",
expectRoute: true,
},
{
name: "Get posts with only limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=20",
expectRoute: true,
},
{
name: "Get posts with only offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "offset=10",
expectRoute: true,
},
{
name: "Search posts with query parameter",
method: http.MethodGet,
path: "/api/posts/search",
queryParams: "q=test",
expectRoute: true,
},
{
name: "Search posts with query, limit, and offset",
method: http.MethodGet,
path: "/api/posts/search",
queryParams: "q=test&limit=15&offset=3",
expectRoute: true,
},
{
name: "Fetch title with URL parameter",
method: http.MethodGet,
path: "/api/posts/title",
queryParams: "url=https://example.com",
expectRoute: true,
},
{
name: "Confirm email with token parameter",
method: http.MethodGet,
path: "/api/auth/confirm",
queryParams: "token=abc123",
expectRoute: true,
},
{
name: "Get posts with invalid limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=abc",
expectRoute: true,
},
{
name: "Get posts with negative limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=-5",
expectRoute: true,
},
{
name: "Get posts with negative offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "offset=-10",
expectRoute: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fullPath := tc.path
if tc.queryParams != "" {
fullPath += "?" + tc.queryParams
}
request := httptest.NewRequest(tc.method, fullPath, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if tc.expectRoute {
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist with query parameters, got 404", tc.method, fullPath)
}
}
})
}
}
func TestRouteConflicts(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
description string
}{
{
name: "posts/search should not match posts/{id}",
method: http.MethodGet,
path: "/api/posts/search",
description: "search route should be matched, not treated as ID",
},
{
name: "posts/title should not match posts/{id}",
method: http.MethodGet,
path: "/api/posts/title",
description: "title route should be matched, not treated as ID",
},
{
name: "posts/{id} should work with numeric ID",
method: http.MethodGet,
path: "/api/posts/123",
description: "numeric ID should match {id} route",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest(tc.method, tc.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
switch tc.path {
case "/api/posts/search":
if recorder.Code == http.StatusNotFound {
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
}
case "/api/posts/title":
if recorder.Code == http.StatusNotFound {
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
}
case "/api/posts/123":
if recorder.Code == http.StatusNotFound {
return
}
if recorder.Code < 400 {
t.Errorf("%s: Route %s %s should return 4xx or 5xx, got %d", tc.description, tc.method, tc.path, recorder.Code)
}
}
})
}
}

View File

@@ -1,12 +1,93 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/mail"
"strings"
"goyco/internal/config"
"goyco/internal/database"
)
const (
defaultTokenExpirationHours = 24
verificationTokenBytes = 32
deletionTokenExpirationHours = 24
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid or expired token")
ErrUsernameTaken = errors.New("username already exists")
ErrEmailTaken = errors.New("email already exists")
ErrInvalidEmail = errors.New("invalid email address")
ErrPasswordTooShort = errors.New("password too short")
ErrEmailNotVerified = errors.New("email not verified")
ErrAccountLocked = errors.New("account is locked")
ErrInvalidVerificationToken = errors.New("invalid verification token")
ErrEmailSenderUnavailable = errors.New("email sender not configured")
ErrDeletionEmailFailed = errors.New("account deletion email failed")
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
ErrUserNotFound = errors.New("user not found")
ErrDeletionRequestNotFound = errors.New("deletion request not found")
)
type AuthResult struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
User *database.User `json:"user"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
VerificationSent bool `json:"verification_sent"`
}
func normalizeEmail(email string) (string, error) {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return "", fmt.Errorf("email is required")
}
parsed, err := mail.ParseAddress(trimmed)
if err != nil {
return "", ErrInvalidEmail
}
return strings.ToLower(parsed.Address), nil
}
func generateVerificationToken() (string, string, error) {
buf := make([]byte, verificationTokenBytes)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate verification token: %w", err)
}
token := hex.EncodeToString(buf)
hashed := HashVerificationToken(token)
return token, hashed, nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func sanitizeUser(user *database.User) *database.User {
if user == nil {
return nil
}
copy := *user
copy.Password = ""
copy.EmailVerificationToken = ""
return &copy
}
type AuthFacade struct {
registrationService *RegistrationService
passwordResetService *PasswordResetService

View File

@@ -1,35 +0,0 @@
package services
import (
"errors"
"goyco/internal/database"
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid or expired token")
ErrUsernameTaken = errors.New("username already exists")
ErrEmailTaken = errors.New("email already exists")
ErrInvalidEmail = errors.New("invalid email address")
ErrPasswordTooShort = errors.New("password too short")
ErrEmailNotVerified = errors.New("email not verified")
ErrAccountLocked = errors.New("account is locked")
ErrInvalidVerificationToken = errors.New("invalid verification token")
ErrEmailSenderUnavailable = errors.New("email sender not configured")
ErrDeletionEmailFailed = errors.New("account deletion email failed")
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
ErrUserNotFound = errors.New("user not found")
ErrDeletionRequestNotFound = errors.New("deletion request not found")
)
type AuthResult struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
User *database.User `json:"user"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
VerificationSent bool `json:"verification_sent"`
}

View File

@@ -1,59 +0,0 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/mail"
"strings"
"goyco/internal/database"
)
const (
defaultTokenExpirationHours = 24
verificationTokenBytes = 32
deletionTokenExpirationHours = 24
)
func normalizeEmail(email string) (string, error) {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return "", fmt.Errorf("email is required")
}
parsed, err := mail.ParseAddress(trimmed)
if err != nil {
return "", ErrInvalidEmail
}
return strings.ToLower(parsed.Address), nil
}
func generateVerificationToken() (string, string, error) {
buf := make([]byte, verificationTokenBytes)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate verification token: %w", err)
}
token := hex.EncodeToString(buf)
hashed := HashVerificationToken(token)
return token, hashed, nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func sanitizeUser(user *database.User) *database.User {
if user == nil {
return nil
}
copy := *user
copy.Password = ""
copy.EmailVerificationToken = ""
return &copy
}

View File

@@ -7,9 +7,10 @@ import (
"sync"
"testing"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/repositories"
"gorm.io/gorm"
)
type mockVoteRepo struct {
@@ -240,6 +241,25 @@ func (m *mockVoteRepo) CountByUserID(userID uint) (int64, error) {
return count, nil
}
func (m *mockVoteRepo) GetVoteCountsByPostID(postID uint) (int, int, error) {
m.mu.RLock()
defer m.mu.RUnlock()
upVotes := 0
downVotes := 0
for _, vote := range m.votes {
if vote.PostID == postID {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
}
return upVotes, downVotes, nil
}
func (m *mockVoteRepo) WithTx(tx *gorm.DB) repositories.VoteRepository {
return m
}

View File

@@ -2,43 +2,70 @@ package templates
import (
"html/template"
"io/fs"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTemplateParsing(t *testing.T) {
templateDir := "./"
func templateFuncMap() template.FuncMap {
return template.FuncMap{
"formatTime": func(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format("02 Jan 2006 15:04")
},
"truncate": func(s string, length int) string {
if len(s) <= length {
return s
}
if length <= 3 {
return s[:length]
}
return s[:length-3] + "..."
},
"substr": func(s string, start, length int) string {
if start >= len(s) {
return ""
}
end := min(start+length, len(s))
return s[start:end]
},
"upper": strings.ToUpper,
}
}
var templateFiles []string
err := filepath.WalkDir(templateDir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if !d.IsDir() && filepath.Ext(path) == ".gohtml" {
templateFiles = append(templateFiles, path)
}
return nil
})
func TestTemplateParsing(t *testing.T) {
layoutPath := filepath.Join(".", "base.gohtml")
require.FileExists(t, layoutPath, "base layout is required for all templates")
partials, err := filepath.Glob(filepath.Join(".", "partials", "*.gohtml"))
require.NoError(t, err)
tmpl := template.New("test")
pages, err := filepath.Glob(filepath.Join(".", "*.gohtml"))
require.NoError(t, err)
require.NotEmpty(t, pages, "no page templates found")
tmpl = tmpl.Funcs(template.FuncMap{
"formatTime": func(any) string { return "2024-01-01" },
"eq": func(a, b any) bool { return a == b },
"ne": func(a, b any) bool { return a != b },
"len": func(s any) int { return 0 },
"range": func(s any) any { return s },
})
for _, page := range pages {
if filepath.Base(page) == "base.gohtml" {
continue
}
for _, file := range templateFiles {
t.Run(file, func(t *testing.T) {
_, err := tmpl.ParseFiles(file)
assert.NoError(t, err, "Template %s should parse without errors", file)
page := page
t.Run(filepath.Base(page), func(t *testing.T) {
t.Parallel()
files := append([]string{layoutPath}, partials...)
files = append(files, page)
tmpl, err := template.New(filepath.Base(page)).Funcs(templateFuncMap()).ParseFiles(files...)
require.NoError(t, err)
require.NotNil(t, tmpl.Lookup("layout"), "layout template should be available")
require.NotNil(t, tmpl.Lookup("content"), "content block should be defined by page templates")
})
}
}

View File

@@ -422,6 +422,24 @@ func (m *MockUserRepository) GetDeletedUsers() ([]database.User, error) {
return []database.User{}, nil
}
func (m *MockUserRepository) GetByUsernamePrefix(prefix string) (*database.User, error) {
m.mu.RLock()
defer m.mu.RUnlock()
for _, user := range m.users {
if len(user.Username) >= len(prefix) && user.Username[:len(prefix)] == prefix {
if len(user.Email) >= 13 && strings.HasSuffix(user.Email, "@goyco.local") {
emailPrefix := user.Email[:len(user.Email)-13]
if len(emailPrefix) >= len(prefix) && emailPrefix[:len(prefix)] == prefix {
return user, nil
}
}
}
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockUserRepository) HardDeleteAll() (int64, error) {
if m.HardDeleteAllFunc != nil {
return m.HardDeleteAllFunc()
@@ -965,6 +983,25 @@ func (m *MockVoteRepository) CountByUserID(userID uint) (int64, error) {
return count, nil
}
func (m *MockVoteRepository) GetVoteCountsByPostID(postID uint) (int, int, error) {
m.mu.RLock()
defer m.mu.RUnlock()
upVotes := 0
downVotes := 0
for _, vote := range m.votes {
if vote.PostID == postID {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
}
return upVotes, downVotes, nil
}
func (m *MockVoteRepository) Count() (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()

View File

@@ -153,6 +153,7 @@ type UserRepositoryStub struct {
UnlockFn func(uint) error
GetPostsFn func(uint, int, int) ([]database.Post, error)
GetDeletedUsersFn func() ([]database.User, error)
GetByUsernamePrefixFn func(string) (*database.User, error)
HardDeleteAllFn func() (int64, error)
CountFn func() (int64, error)
WithTxFn func(*gorm.DB) repositories.UserRepository
@@ -281,6 +282,13 @@ func (s *UserRepositoryStub) GetDeletedUsers() ([]database.User, error) {
return nil, nil
}
func (s *UserRepositoryStub) GetByUsernamePrefix(prefix string) (*database.User, error) {
if s != nil && s.GetByUsernamePrefixFn != nil {
return s.GetByUsernamePrefixFn(prefix)
}
return nil, gorm.ErrRecordNotFound
}
func (s *UserRepositoryStub) HardDeleteAll() (int64, error) {
if s != nil && s.HardDeleteAllFn != nil {
return s.HardDeleteAllFn()

View File

@@ -30,7 +30,7 @@ if [ ! -d "docs" ]; then
mkdir -p docs
fi
SWAGGER_DIRECTORIES="cmd/goyco,internal/handlers"
SWAGGER_DIRECTORIES="cmd/goyco,internal/handlers,internal/dto"
SWAGGER_MAIN_FILE="main.go"
SWAGGER_OUTPUT_DIR="docs"

View File

@@ -44,5 +44,3 @@ GRANT ALL PRIVILEGES ON DATABASE goyco TO goyco;
EOF
echo "PostgreSQL 18 installed, database 'goyco' and user 'goyco' set up."