Compare commits

...

96 Commits

Author SHA1 Message Date
07c6b89525 docs: remove project structure, boring and hard to maintain 2025-12-20 17:35:16 +01:00
817205d42f refactor: modernize using min() 2025-12-16 15:45:51 +01:00
199ac143a4 refactor: replace interface{} by any 2025-12-16 15:05:23 +01:00
aa7e259ed0 format: shfmt 2025-12-16 15:02:42 +01:00
4587609e17 refactor: create createTestRouter and test edge cases 2025-12-14 21:14:42 +01:00
33da6503e3 test: also test put/delete routes 2025-12-14 21:06:15 +01:00
cafc44ed77 test: add a test for route parameters 2025-12-14 21:04:36 +01:00
1480135e75 test: verified all routes to exist 2025-12-14 21:02:25 +01:00
02a764c736 clean: remove merged files 2025-12-14 20:52:14 +01:00
6834ad7764 refactor: merge facade, types and utils into one auth_service.go 2025-12-14 20:52:03 +01:00
dcf054046f test: fix parallel processor test expectations and setup 2025-12-10 07:30:27 +01:00
d2a788933d fix: track completed items in the main loop instead of using the index 2025-12-10 07:29:55 +01:00
18be3950dc clean: obsolete function 2025-12-09 22:07:30 +01:00
f9cb140e95 clean: removed the obsolete functions outputMessage and outputError 2025-12-09 22:06:12 +01:00
86d4835ccf feat: seed user is now uniq 2025-12-09 22:03:26 +01:00
feddb2ed43 test: new unit test for EnsureSeedUser 2025-12-09 22:03:16 +01:00
457b5c88e2 refactor: improve seed consistency validation 2025-12-09 21:53:03 +01:00
a8d363b2bf fix: templates now parse with the same func map as the page handler 2025-12-09 21:37:21 +01:00
0cd68e847c refactor: add a helper to centralize CSRF token retrieval 2025-12-09 15:58:28 +01:00
df6aeed713 docs: unocss it is 2025-12-04 20:43:30 +01:00
785faeb60c feat: update alpine to 3.23 2025-12-04 10:01:19 +01:00
0623c027ba docs: prepare CONTRIBUTING.md 2025-12-03 20:57:32 +01:00
d4e91b6034 refactor: complete refactor and better helpers use 2025-11-29 15:19:41 +01:00
7d46d3e81b clean: remove the unused expectedValue in assertHeader (always set to "") 2025-11-29 15:19:28 +01:00
216aaf3117 refactor: clean code and use new request helpers 2025-11-29 14:58:52 +01:00
435047ad0c refactor: clean code 2025-11-29 14:58:37 +01:00
b7ee8bd11d refactor: clean variable names and use new request helpers 2025-11-29 14:58:20 +01:00
040cd48be8 refactor: clean variables 2025-11-29 14:58:07 +01:00
2dd16e0e00 refactor: complete 2025-11-29 14:56:18 +01:00
d6db70cc79 refactor: clean code and variables, use new request helpers 2025-11-29 14:55:47 +01:00
58e10ade7d refactor: clean variable names and modernize code 2025-11-29 14:50:35 +01:00
7403a75d8e refactor: clean variable naming 2025-11-29 14:46:26 +01:00
b429bc11af refactor: clean code and use new request helpers 2025-11-29 14:41:38 +01:00
2ec5c28fb5 refactor: rename variables and clean code 2025-11-29 14:37:18 +01:00
3743a99e40 refactor: req -> request, rec -> recorder, reqBody -> requestBody... 2025-11-29 14:21:07 +01:00
5710921b87 refactor: use new request helpers 2025-11-29 14:17:25 +01:00
84d9c81484 refactor: rec -> recorder, req -> request and modernize loop 2025-11-29 14:15:07 +01:00
b0c2038927 feat: add new helpers to make requests properly in integration tests 2025-11-29 14:11:32 +01:00
fd88931146 refactor: variable names and modernize loop 2025-11-25 14:37:59 +01:00
6ce0f4dfad refactor: name variables 2025-11-25 10:13:10 +01:00
68e3dceefc refactor: name variables 2025-11-25 10:08:48 +01:00
cded14c526 fix: force correct mime types for static files after modifying compression middleware's buffering 2025-11-24 07:53:08 +01:00
cfca668ca6 docs: clean readme 2025-11-23 21:56:53 +01:00
279255b587 fix: don't let rate limiting fails the test 2025-11-23 21:48:11 +01:00
b83f8c2228 fix: update ValidationMiddleware to return a JSON error response when JSON decoding fails 2025-11-23 21:42:27 +01:00
aabc48128c fix: use router in handlers integration tests (for dto validation) 2025-11-23 15:10:55 +01:00
68716b977b fix: verify XSS sanitization in handler response instead of repository stub 2025-11-23 15:01:54 +01:00
dbe1600632 fix: indentation 2025-11-23 14:49:31 +01:00
458e25cf79 fix: modify compression middleware to pass through redirects immediately without buffering 2025-11-23 14:48:59 +01:00
d4595d8dbf fix: properly encoding the flash message in the redirect URL 2025-11-23 14:48:39 +01:00
c5418f4e4c docs: update swagger 2025-11-23 14:26:52 +01:00
db0369225e refactor: update references to VoteRequest 2025-11-23 14:26:45 +01:00
07ac965b3d refactor: use consistent naming (VoteRequest -> CastVoteRequest) 2025-11-23 14:26:19 +01:00
e2e5d42035 feat: add SetValidatedDTOInContext to support test helper functions 2025-11-23 14:22:59 +01:00
6e4b41894f fix: update test cases to use createCreatePostRequests 2025-11-23 14:22:35 +01:00
0a8ed2e27c fix: add explicite validation check for empty url, title and content length 2025-11-23 14:21:30 +01:00
216e8657f6 feat: add generic createRequestWithDTO along with helpers functions 2025-11-23 14:20:59 +01:00
fb7206c0a2 fix: test context handling 2025-11-23 14:20:09 +01:00
c25926514b fix: add explicit empty-field validation check in handlers 2025-11-23 14:19:54 +01:00
964785e494 docs: update swagger 2025-11-23 13:47:38 +01:00
9c67cd2a47 feat: update vote handler to use dto VoteRequest and update MountRoutes 2025-11-23 13:47:31 +01:00
8b5cc8e939 feat: add VoteRequest with its validation types 2025-11-23 13:46:51 +01:00
0e71b28615 feat: update CreateUser to use dto.RegisterRequest and update MountRoutes to apply validation middleware 2025-11-23 13:43:47 +01:00
cd740da57a feat: update methods to use validated DTOs and update MountRoutes 2025-11-23 13:43:14 +01:00
abe4a3dc88 feat: update handlers to use GetValidatedDTO instead of manual decoding and update MountRoutes to wrap handlers with WithValidation for all DTO-based routes 2025-11-23 13:42:52 +01:00
738243d945 feat: add ValidationMiddleware to RouteModuleConfig 2025-11-23 13:41:55 +01:00
4fbdfb6e4a feat: add two helpers function to retrieve validated DTOs from request context and to apply validation middleware 2025-11-23 13:41:07 +01:00
6bb3a78b88 feat: Add ValidationMiddleware to router configuration 2025-11-23 13:40:31 +01:00
54e37e59fc docs: update swagger 2025-11-23 13:35:00 +01:00
5d4b38ddc4 feat: add validation tags to request DTOs 2025-11-23 13:34:53 +01:00
7dc119ecde docs: update swagger 2025-11-23 13:17:14 +01:00
52c9f4a02b feat: add internal/dto to swagger directories 2025-11-23 13:16:44 +01:00
be91a135bc clean: empty line 2025-11-23 13:14:41 +01:00
2d7ff9778b feat: update swagger comments following dtos relocation 2025-11-23 13:14:07 +01:00
4ff3fd3583 refactor: remove UpdatePostRequest definition and update swagger comments 2025-11-23 13:13:53 +01:00
73121cad15 refactor: remove all request DTO, update swagger comments and update token related methods to use dto ones 2025-11-23 13:13:23 +01:00
c5bf1b2fd8 feat: locate post-related request DTOs 2025-11-23 13:12:36 +01:00
eedebe60d1 feat: locate auth-related request DTOs 2025-11-23 13:12:10 +01:00
80fb37371f update: fix go version and update alpine to 3.22 2025-11-23 10:44:28 +01:00
fea49fad8d fix: add missing method to mock 2025-11-21 17:07:26 +01:00
4b04461ebb style: minor formatting adjustments 2025-11-21 17:06:04 +01:00
533e8c3d46 feat: add GetByUsernamePrefixFn field and method to UserRepositoryStub 2025-11-21 17:05:48 +01:00
df568291f1 feat: add GetByUsernamePrefix implementation to MockUserRepository 2025-11-21 17:05:31 +01:00
81acce62b1 feat: add GetByUsernamePrefix method to interface and add implementation 2025-11-21 17:05:01 +01:00
989a61e7d5 feat: use getByUsernamePrefix to optimize findExistingSeedUser() 2025-11-21 17:04:35 +01:00
3ffd83b0fb feat: ignore docs in make format 2025-11-21 17:02:06 +01:00
62d466e4fa refactor: use go generics 2025-11-21 16:56:26 +01:00
0cd428d5d9 feat: use connection pooling instead of a single connection 2025-11-21 16:53:46 +01:00
5c239ad61d feat: add missing GetVoteCountsByPostID method to the errorVoteRepository test mock 2025-11-21 16:50:23 +01:00
01f2b1fe75 feat: remove loop and use GetVoteCountsByPostID 2025-11-21 16:48:48 +01:00
28134c101c feat: add GetVoteCountsByPostID to the mock for testing 2025-11-21 16:48:15 +01:00
2f78370d43 feat: GetVoteCountsByPostID: use a single sql query to returns up votes and down votes counts 2025-11-21 16:47:52 +01:00
39598a166d feat: remove redundat getbyemail call to reduce db query by 2 (1Q/user creation instead of 2) 2025-11-21 16:43:46 +01:00
fa9474d863 revert: db transaction use, avoiding the pgx RETURNING issue while maintaining data consistency 2025-11-21 16:31:06 +01:00
34a97994b3 feat: improve testing to use production code paths and better coverage 2025-11-21 16:26:21 +01:00
eb5f93ffd0 clean: remove duplicate sequential helpers 2025-11-21 16:25:27 +01:00
65 changed files with 3978 additions and 3715 deletions

1
.prettierignore Normal file
View File

@@ -0,0 +1 @@
docs/

View File

@@ -1,4 +1,4 @@
ARG GO_VERSION=1.25.3
ARG GO_VERSION=1.25.4
# Building the binary using a golang alpine image
FROM golang:${GO_VERSION}-alpine AS go-builder
@@ -11,7 +11,7 @@ ARG TARGETARCH=amd64
RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags="-s -w" -o /out/goyco ./cmd/goyco
# building the application image
FROM alpine:3.21
FROM alpine:3.23
RUN addgroup -S goyco && adduser -S -G goyco goyco \
&& apk add --no-cache ca-certificates tzdata
WORKDIR /app

View File

@@ -44,8 +44,8 @@ clean:
rm -fr dist/*
format:
$(PRETTIER) -w .
$(GO) fmt ./...
$(PRETTIER) -w . --ignore-path .prettierignore
$(GO) fmt $(shell $(GO) list ./... | grep -v 'docs')
lint:
$(GOLANGCI_LINT) run

View File

@@ -354,44 +354,6 @@ It will start the application in development mode. You can also run it as a daem
Then, use `./bin/goyco` to manage the application and notably to seed the database with sample data.
### Project Structure
```bash
goyco/
├── bin/ # Compiled binaries (created after build)
├── cmd/
│ └── goyco/ # Main CLI application entrypoint
├── docker/ # Docker Compose & related files
├── docs/ # Documentation and API specs
├── internal/
│ ├── config/ # Configuration management
│ ├── database/ # Database models and access
│ ├── dto/ # Data Transfer Objects (DTOs)
│ ├── e2e/ # End-to-end tests
│ ├── fuzz/ # Fuzz tests
│ ├── handlers/ # HTTP handlers
│ ├── integration/ # Integration tests
│ ├── middleware/ # HTTP middleware
│ ├── repositories/ # Data access layer
│ ├── security/ # Security and auth logic
│ ├── server/ # HTTP server implementation
│ ├── services/ # Business logic
│ ├── static/ # Static web assets
│ ├── templates/ # HTML templates
│ ├── testutils/ # Test helpers/utilities
│ ├── validation/ # Input validation
│ └── version/ # Version information
├── scripts/ # Utility/maintenance scripts
├── services/
│ └── goyco.service # Systemd service unit example
├── .env.example # Environment variable example
├── AUTHORS # Authors file
├── Dockerfile # Docker build file
├── LICENSE # License file
├── Makefile # Project build/test targets
└── README.md # This file
```
### Testing
```bash
@@ -437,31 +399,10 @@ This will regenerate the swagger documentation and update the `docs/swagger.json
- [ ] add right management within the app
- [ ] add an admin backoffice to manage rights, users, content and settings
- [ ] add a way to run read-only communities
- [ ] maybe use a css framework instead of raw css
- [ ] migrate raw CSS to UnoCSS
- [ ] kubernetes deployment
- [ ] store configuration in the database
## Contributing
Feedbacks are welcome!
But as it's a personal gitea and you cannot create accounts, feel free to contact me at <sandro@cazzaniga.fr> to get one.
Once you have it, follow the usual workflow:
1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests for new functionality
5. Ensure all tests pass
6. Submit a pull request
Then, I'll review your changes and merge them if they are good.
## License
This project is licensed under the GNU General Public License v3.0 or later (GPLv3+). See the [LICENSE](LICENSE) file for details.
---
**Goyco** - A modern news aggregation platform built with Go, PostgreSQL and most importantly, love.

View File

@@ -8,9 +8,10 @@ import (
"os"
"sync"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"gorm.io/gorm"
)
var ErrHelpRequested = errors.New("help requested")
@@ -40,11 +41,11 @@ var (
)
func defaultDBConnector(cfg *config.Config) (*gorm.DB, func() error, error) {
db, err := database.Connect(cfg)
poolManager, err := database.ConnectWithPool(cfg)
if err != nil {
return nil, nil, err
}
return db, func() error { return database.Close(db) }, nil
return poolManager.GetDB(), func() error { return poolManager.Close() }, nil
}
func SetDBConnector(connector DBConnector) {
@@ -118,26 +119,6 @@ func outputJSON(v interface{}) error {
return encoder.Encode(v)
}
func outputMessage(message string, args ...interface{}) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"message": fmt.Sprintf(message, args...),
})
} else {
fmt.Printf(message+"\n", args...)
}
}
func outputError(message string, args ...interface{}) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"error": fmt.Sprintf(message, args...),
})
} else {
fmt.Fprintf(os.Stderr, message+"\n", args...)
}
}
func outputWarning(message string, args ...interface{}) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{

View File

@@ -5,9 +5,10 @@ import (
"fmt"
"os"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"gorm.io/gorm"
)
func HandleMigrateCommand(cfg *config.Config, name string, args []string) error {
@@ -37,7 +38,7 @@ func runMigrateCommand(db *gorm.DB) error {
return fmt.Errorf("run migrations: %w", err)
}
if IsJSONOutput() {
outputJSON(map[string]interface{}{
outputJSON(map[string]any{
"action": "migrations_applied",
"status": "success",
})

View File

@@ -42,14 +42,23 @@ func (p *ParallelProcessor) SetPasswordHash(hash string) {
p.passwordHash = hash
}
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
type indexedResult[T any] struct {
value T
index int
}
results := make(chan userResult, count)
func processInParallel[T any](
ctx context.Context,
maxWorkers int,
count int,
processor func(index int) (T, error),
errorPrefix string,
progress *ProgressIndicator,
) ([]T, error) {
results := make(chan indexedResult[T], count)
errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i := range count {
@@ -65,13 +74,13 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
}
defer func() { <-semaphore }()
user, err := p.createSingleUser(userRepo, index+1)
value, err := processor(index + 1)
if err != nil {
errors <- fmt.Errorf("create user %d: %w", index+1, err)
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return
}
results <- userResult{user: user, index: index}
results <- indexedResult[T]{value: value, index: index}
}(i)
}
@@ -81,7 +90,7 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
close(errors)
}()
users := make([]database.User, count)
items := make([]T, count)
completed := 0
firstError := make(chan error, 1)
@@ -101,9 +110,9 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
select {
case result, ok := <-results:
if !ok {
return users, nil
return items, nil
}
users[result.index] = result.user
items[result.index] = result.value
completed++
if progress != nil {
progress.Update(completed)
@@ -111,26 +120,59 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
case err := <-firstError:
return nil, err
case <-ctx.Done():
return nil, fmt.Errorf("timeout creating users: %w", ctx.Err())
return nil, fmt.Errorf("timeout: %w", ctx.Err())
}
}
return users, nil
return items, nil
}
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.User, error) {
return p.createSingleUser(userRepo, index)
},
"create user",
progress,
)
}
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan postResult, count)
return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.Post, error) {
return p.createSinglePost(postRepo, authorID, index)
},
"create post",
progress,
)
}
func processItemsInParallel[T any, R any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) (R, error),
errorPrefix string,
aggregator func(accumulator R, value R) R,
initialValue R,
progress *ProgressIndicator,
) (R, error) {
count := len(items)
results := make(chan indexedResult[R], count)
errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i := range count {
for i, item := range items {
wg.Add(1)
go func(index int) {
go func(index int, item T) {
defer wg.Done()
select {
@@ -141,14 +183,14 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
}
defer func() { <-semaphore }()
post, err := p.createSinglePost(postRepo, authorID, index+1)
value, err := processor(index, item)
if err != nil {
errors <- fmt.Errorf("create post %d: %w", index+1, err)
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return
}
results <- postResult{post: post, index: index}
}(i)
results <- indexedResult[R]{value: value, index: index}
}(i, item)
}
go func() {
@@ -157,7 +199,7 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
close(errors)
}()
posts := make([]database.Post, count)
accumulator := initialValue
completed := 0
firstError := make(chan error, 1)
@@ -177,36 +219,56 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
select {
case result, ok := <-results:
if !ok {
return posts, nil
return accumulator, nil
}
posts[result.index] = result.post
accumulator = aggregator(accumulator, result.value)
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return nil, err
return initialValue, err
case <-ctx.Done():
return nil, fmt.Errorf("timeout creating posts: %w", ctx.Err())
return initialValue, fmt.Errorf("timeout: %w", ctx.Err())
}
}
return posts, nil
return accumulator, nil
}
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan voteResult, len(posts))
errors := make(chan error, len(posts))
return processItemsInParallel(ctx, p.maxWorkers, posts,
func(index int, post database.Post) (int, error) {
return p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
},
"create votes for post",
func(acc, val int) int { return acc + val },
0,
progress,
)
}
semaphore := make(chan struct{}, p.maxWorkers)
func processItemsInParallelNoResult[T any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) error,
errorFormatter func(index int, item T, err error) error,
progress *ProgressIndicator,
) error {
count := len(items)
errors := make(chan error, count)
completions := make(chan struct{}, count)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
for i, item := range items {
wg.Add(1)
go func(index int, post database.Post) {
go func(index int, item T) {
defer wg.Done()
select {
@@ -217,23 +279,26 @@ func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepo
}
defer func() { <-semaphore }()
votes, err := p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
err := processor(index, item)
if err != nil {
errors <- fmt.Errorf("create votes for post %d: %w", post.ID, err)
if errorFormatter != nil {
errors <- errorFormatter(index, item, err)
} else {
errors <- fmt.Errorf("process item %d: %w", index+1, err)
}
return
}
results <- voteResult{votes: votes, index: index}
}(i, post)
completions <- struct{}{}
}(i, item)
}
go func() {
wg.Wait()
close(results)
close(errors)
close(completions)
}()
totalVotes := 0
completed := 0
firstError := make(chan error, 1)
@@ -249,88 +314,39 @@ func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepo
}
}()
for completed < len(posts) {
for completed < count {
select {
case result, ok := <-results:
case _, ok := <-completions:
if !ok {
return totalVotes, nil
return nil
}
totalVotes += result.votes
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return 0, err
case <-ctx.Done():
return 0, fmt.Errorf("timeout creating votes: %w", ctx.Err())
}
}
return totalVotes, nil
}
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
errors := make(chan error, len(posts))
semaphore := make(chan struct{}, p.maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
wg.Add(1)
go func(index int, post database.Post) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
err := p.updateSinglePostScore(postRepo, voteRepo, post)
if err != nil {
errors <- fmt.Errorf("update post %d scores: %w", post.ID, err)
return
}
if progress != nil {
progress.Update(index + 1)
}
}(i, post)
}
go func() {
wg.Wait()
close(errors)
}()
for err := range errors {
if err != nil {
return err
case <-ctx.Done():
return fmt.Errorf("timeout: %w", ctx.Err())
}
}
return nil
}
type userResult struct {
user database.User
index int
}
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
type postResult struct {
post database.Post
index int
}
type voteResult struct {
votes int
index int
return processItemsInParallelNoResult(ctx, p.maxWorkers, posts,
func(index int, post database.Post) error {
return p.updateSinglePostScore(postRepo, voteRepo, post)
},
func(index int, post database.Post, err error) error {
return fmt.Errorf("update post %d scores: %w", post.ID, err)
},
progress,
)
}
func (p *ParallelProcessor) generateRandomIdentifier() string {
@@ -352,12 +368,7 @@ func (p *ParallelProcessor) createSingleUser(userRepo repositories.UserRepositor
const maxRetries = 10
for range maxRetries {
user, err := userRepo.GetByEmail(email)
if err == nil {
return *user, nil
}
user = &database.User{
user := &database.User{
Username: username,
Email: email,
Password: p.passwordHash,

View File

@@ -2,15 +2,15 @@ package commands_test
import (
"errors"
"fmt"
"sync"
"testing"
"golang.org/x/crypto/bcrypt"
"goyco/cmd/goyco/commands"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/testutils"
"golang.org/x/crypto/bcrypt"
)
func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
@@ -25,7 +25,7 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
wantErr bool
}{
{
name: "creates users with deterministic fields",
name: "creates users with required fields",
count: successCount,
repoFactory: func() repositories.UserRepository {
base := testutils.NewMockUserRepository()
@@ -37,14 +37,24 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
if len(got) != successCount {
t.Fatalf("expected %d users, got %d", successCount, len(got))
}
usernames := make(map[string]bool)
for i, user := range got {
expectedUsername := fmt.Sprintf("user_%d", i+1)
expectedEmail := fmt.Sprintf("user_%d@goyco.local", i+1)
if user.Username != expectedUsername {
t.Errorf("user %d username mismatch: got %q want %q", i, user.Username, expectedUsername)
if user.Username == "" {
t.Errorf("user %d expected non-empty username", i)
}
if user.Email != expectedEmail {
t.Errorf("user %d email mismatch: got %q want %q", i, user.Email, expectedEmail)
if len(user.Username) < 6 || user.Username[:5] != "user_" {
t.Errorf("user %d username should start with 'user_', got %q", i, user.Username)
}
if usernames[user.Username] {
t.Errorf("user %d duplicate username: %q", i, user.Username)
}
usernames[user.Username] = true
if user.Email == "" {
t.Errorf("user %d expected non-empty email", i)
}
if len(user.Email) < 20 || user.Email[:5] != "user_" || user.Email[len(user.Email)-12:] != "@goyco.local" {
t.Errorf("user %d email should match pattern 'user_*@goyco.local', got %q", i, user.Email)
}
if !user.EmailVerified {
t.Errorf("user %d expected EmailVerified to be true", i)
@@ -83,6 +93,11 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
t.Parallel()
repo := tt.repoFactory()
p := commands.NewParallelProcessor()
passwordHash, err := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("failed to generate password hash: %v", err)
}
p.SetPasswordHash(string(passwordHash))
got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress)
if gotErr != nil {
if !tt.wantErr {

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"math/rand"
"os"
"strings"
"sync"
"time"
@@ -36,17 +35,6 @@ func initSeedRand() {
})
}
func generateRandomIdentifier() string {
initSeedRand()
const length = 12
const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
identifier := make([]byte, length)
for i := range identifier {
identifier[i] = chars[seedRandSource.Intn(len(chars))]
}
return string(identifier)
}
func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printSeedUsage)
if err := parseCommand(fs, args, name); err != nil {
@@ -57,15 +45,10 @@ func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
}
return withDatabase(cfg, func(db *gorm.DB) error {
return db.Transaction(func(tx *gorm.DB) error {
userRepo := repositories.NewUserRepository(db).WithTx(tx)
postRepo := repositories.NewPostRepository(db).WithTx(tx)
voteRepo := repositories.NewVoteRepository(db).WithTx(tx)
if err := runSeedCommand(userRepo, postRepo, voteRepo, fs.Args()); err != nil {
return err
}
return nil
})
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
return runSeedCommand(userRepo, postRepo, voteRepo, fs.Args())
})
}
@@ -219,7 +202,7 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
progress.Complete()
}
if err := validateSeedConsistency(postRepo, voteRepo, allUsers, posts); err != nil {
if err := validateSeedConsistency(voteRepo, allUsers, posts); err != nil {
return fmt.Errorf("seed consistency validation failed: %w", err)
}
@@ -242,285 +225,93 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
return nil
}
func findExistingSeedUser(userRepo repositories.UserRepository) (*database.User, error) {
users, err := userRepo.GetAll(100, 0)
if err != nil {
return nil, err
}
for _, user := range users {
if len(user.Username) >= 11 && user.Username[:11] == "seed_admin_" {
if len(user.Email) >= 13 && strings.HasSuffix(user.Email, "@goyco.local") {
emailPrefix := user.Email[:len(user.Email)-13]
if len(emailPrefix) >= 11 && emailPrefix[:11] == "seed_admin_" {
return &user, nil
}
}
}
}
return nil, fmt.Errorf("no existing seed user found")
}
const (
seedUsername = "seed_admin"
seedEmail = "seed_admin@goyco.local"
)
func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) {
existingUser, err := findExistingSeedUser(userRepo)
if err == nil && existingUser != nil {
return existingUser, nil
}
randomID := generateRandomIdentifier()
seedUsername := fmt.Sprintf("seed_admin_%s", randomID)
seedEmail := fmt.Sprintf("seed_admin_%s@goyco.local", randomID)
const maxRetries = 10
for range maxRetries {
user, err := userRepo.GetByEmail(seedEmail)
if err == nil {
return user, nil
}
user = &database.User{
Username: seedUsername,
Email: seedEmail,
Password: passwordHash,
EmailVerified: true,
}
if err := userRepo.Create(user); err != nil {
randomID = generateRandomIdentifier()
seedUsername = fmt.Sprintf("seed_admin_%s", randomID)
seedEmail = fmt.Sprintf("seed_admin_%s@goyco.local", randomID)
continue
}
if user, err := userRepo.GetByUsername(seedUsername); err == nil {
return user, nil
}
return nil, fmt.Errorf("failed to create seed user after %d attempts", maxRetries)
}
func createRandomUsers(userRepo repositories.UserRepository, count int, passwordHash string) ([]database.User, error) {
var users []database.User
for i := range count {
username := fmt.Sprintf("user_%d", i+1)
email := fmt.Sprintf("user_%d@goyco.local", i+1)
user := &database.User{
Username: username,
Email: email,
Password: passwordHash,
EmailVerified: true,
}
if err := userRepo.Create(user); err != nil {
return nil, fmt.Errorf("create user %d: %w", i+1, err)
}
users = append(users, *user)
user := &database.User{
Username: seedUsername,
Email: seedEmail,
Password: passwordHash,
EmailVerified: true,
}
return users, nil
}
func createRandomPosts(postRepo repositories.PostRepository, authorID uint, count int) ([]database.Post, error) {
var posts []database.Post
sampleTitles := []string{
"Amazing JavaScript Framework",
"Python Best Practices",
"Go Performance Tips",
"Database Optimization",
"Web Security Guide",
"Machine Learning Basics",
"Cloud Architecture",
"DevOps Automation",
"API Design Patterns",
"Frontend Optimization",
"Backend Scaling",
"Container Orchestration",
"Microservices Architecture",
"Testing Strategies",
"Code Review Process",
"Version Control Best Practices",
"Continuous Integration",
"Monitoring and Alerting",
"Error Handling Patterns",
"Data Structures Explained",
if err := userRepo.Create(user); err != nil {
return nil, fmt.Errorf("failed to create seed user: %w", err)
}
sampleDomains := []string{
"example.com",
"techblog.org",
"devguide.net",
"programming.io",
"codeexamples.com",
"tutorialhub.org",
"bestpractices.dev",
"learnprogramming.net",
"codingtips.org",
"softwareengineering.com",
}
for i := range count {
title := sampleTitles[i%len(sampleTitles)]
if i >= len(sampleTitles) {
title = fmt.Sprintf("%s - Part %d", title, (i/len(sampleTitles))+1)
}
domain := sampleDomains[i%len(sampleDomains)]
path := generateRandomPath()
url := fmt.Sprintf("https://%s%s", domain, path)
content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", i+1, title)
post := &database.Post{
Title: title,
URL: url,
Content: content,
AuthorID: &authorID,
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
if err := postRepo.Create(post); err != nil {
return nil, fmt.Errorf("create post %d: %w", i+1, err)
}
posts = append(posts, *post)
}
return posts, nil
}
func generateRandomPath() string {
initSeedRand()
pathLength := seedRandSource.Intn(20)
path := "/article/"
for i := 0; i < pathLength+5; i++ {
randomChar := seedRandSource.Intn(26)
path += string(rune('a' + randomChar))
}
return path
}
func createRandomVotes(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int) (int, error) {
initSeedRand()
totalVotes := 0
for _, post := range posts {
numVotes := seedRandSource.Intn(avgVotesPerPost*2 + 1)
if numVotes == 0 && avgVotesPerPost > 0 {
if seedRandSource.Intn(5) > 0 {
numVotes = 1
}
}
usedUsers := make(map[uint]bool)
for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
userIdx := seedRandSource.Intn(len(users))
user := users[userIdx]
if usedUsers[user.ID] {
continue
}
usedUsers[user.ID] = true
voteTypeInt := seedRandSource.Intn(10)
var voteType database.VoteType
if voteTypeInt < 7 {
voteType = database.VoteUp
} else {
voteType = database.VoteDown
}
vote := &database.Vote{
UserID: &user.ID,
PostID: post.ID,
Type: voteType,
}
if err := voteRepo.Create(vote); err != nil {
return totalVotes, fmt.Errorf("create vote for post %d: %w", post.ID, err)
}
totalVotes++
}
}
return totalVotes, nil
}
func updatePostScores(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post) error {
for _, post := range posts {
upVotes, downVotes, err := getVoteCounts(voteRepo, post.ID)
if err != nil {
return fmt.Errorf("get vote counts for post %d: %w", post.ID, err)
}
post.UpVotes = upVotes
post.DownVotes = downVotes
post.Score = upVotes - downVotes
if err := postRepo.Update(&post); err != nil {
return fmt.Errorf("update post %d scores: %w", post.ID, err)
}
}
return nil
return user, nil
}
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
votes, err := voteRepo.GetByPostID(postID)
if err != nil {
return 0, 0, err
}
upVotes := 0
downVotes := 0
for _, vote := range votes {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
return upVotes, downVotes, nil
return voteRepo.GetVoteCountsByPostID(postID)
}
func validateSeedConsistency(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
userIDs := make(map[uint]bool)
func validateSeedConsistency(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
userIDSet := make(map[uint]struct{}, len(users))
for _, user := range users {
userIDs[user.ID] = true
userIDSet[user.ID] = struct{}{}
}
postIDSet := make(map[uint]struct{}, len(posts))
for _, post := range posts {
postIDSet[post.ID] = struct{}{}
}
for _, post := range posts {
if post.AuthorID == nil {
return fmt.Errorf("post %d has no author", post.ID)
}
if !userIDs[*post.AuthorID] {
return fmt.Errorf("post %d has invalid author ID %d", post.ID, *post.AuthorID)
if err := validatePost(post, userIDSet); err != nil {
return err
}
votes, err := voteRepo.GetByPostID(post.ID)
if err != nil {
return fmt.Errorf("get votes for post %d: %w", post.ID, err)
return fmt.Errorf("failed to retrieve votes for post %d: %w", post.ID, err)
}
for _, vote := range votes {
if vote.UserID != nil && !userIDs[*vote.UserID] {
return fmt.Errorf("vote %d has invalid user ID %d", vote.ID, *vote.UserID)
}
if vote.PostID != post.ID {
return fmt.Errorf("vote %d has invalid post ID %d (expected %d)", vote.ID, vote.PostID, post.ID)
}
if err := validateVotesForPost(post.ID, votes, userIDSet, postIDSet); err != nil {
return err
}
}
return nil
}
func validatePost(post database.Post, userIDSet map[uint]struct{}) error {
if post.AuthorID == nil {
return fmt.Errorf("post %d has no author ID", post.ID)
}
if _, exists := userIDSet[*post.AuthorID]; !exists {
return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID)
}
return nil
}
func validateVotesForPost(postID uint, votes []database.Vote, userIDSet map[uint]struct{}, postIDSet map[uint]struct{}) error {
for _, vote := range votes {
if vote.PostID != postID {
return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, postID)
}
if _, exists := postIDSet[vote.PostID]; !exists {
return fmt.Errorf("vote %d references non-existent post ID %d", vote.ID, vote.PostID)
}
if vote.UserID != nil {
if _, exists := userIDSet[*vote.UserID]; !exists {
return fmt.Errorf("vote %d references non-existent user ID %d", vote.ID, *vote.UserID)
}
}
if vote.Type != database.VoteUp && vote.Type != database.VoteDown {
return fmt.Errorf("vote %d has invalid type %q", vote.ID, vote.Type)
}
}

View File

@@ -9,7 +9,6 @@ import (
"goyco/internal/repositories"
"goyco/internal/testutils"
"golang.org/x/crypto/bcrypt"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
@@ -25,49 +24,64 @@ func TestSeedCommand(t *testing.T) {
t.Fatalf("Failed to migrate database: %v", err)
}
err = db.Transaction(func(tx *gorm.DB) error {
userRepo := repositories.NewUserRepository(db).WithTx(tx)
postRepo := repositories.NewPostRepository(db).WithTx(tx)
voteRepo := repositories.NewVoteRepository(db).WithTx(tx)
return seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "2", "--posts", "5", "--votes-per-post", "3"})
})
if err != nil {
t.Fatalf("Failed to seed database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
seedPasswordHash, err := bcrypt.GenerateFromPassword([]byte("seed-password"), bcrypt.DefaultCost)
users, err := userRepo.GetAll(100, 0)
if err != nil {
t.Fatalf("Failed to hash seed password: %v", err)
t.Fatalf("Failed to get users: %v", err)
}
seedUser, err := ensureSeedUser(userRepo, string(seedPasswordHash))
if err != nil {
t.Fatalf("Failed to ensure seed user: %v", err)
seedUserCount := 0
var seedUser *database.User
regularUserCount := 0
for i := range users {
if users[i].Username == "seed_admin" {
seedUserCount++
seedUser = &users[i]
} else if strings.HasPrefix(users[i].Username, "user_") {
regularUserCount++
}
}
if !strings.HasPrefix(seedUser.Username, "seed_admin_") {
t.Errorf("Expected username to start with 'seed_admin_', got '%s'", seedUser.Username)
if seedUserCount != 1 {
t.Errorf("Expected 1 seed user, got %d", seedUserCount)
}
if !strings.HasPrefix(seedUser.Email, "seed_admin_") || !strings.HasSuffix(seedUser.Email, "@goyco.local") {
t.Errorf("Expected email to start with 'seed_admin_' and end with '@goyco.local', got '%s'", seedUser.Email)
if seedUser == nil {
t.Fatal("Expected seed user to be created")
}
if seedUser.Username != "seed_admin" {
t.Errorf("Expected username to be 'seed_admin', got '%s'", seedUser.Username)
}
if seedUser.Email != "seed_admin@goyco.local" {
t.Errorf("Expected email to be 'seed_admin@goyco.local', got '%s'", seedUser.Email)
}
if !seedUser.EmailVerified {
t.Error("Expected seed user to be email verified")
}
userPasswordHash, err := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("Failed to hash user password: %v", err)
if regularUserCount != 2 {
t.Errorf("Expected 2 regular users, got %d", regularUserCount)
}
users, err := createRandomUsers(userRepo, 2, string(userPasswordHash))
posts, err := postRepo.GetAll(100, 0)
if err != nil {
t.Fatalf("Failed to create random users: %v", err)
}
if len(users) != 2 {
t.Errorf("Expected 2 users, got %d", len(users))
}
posts, err := createRandomPosts(postRepo, seedUser.ID, 5)
if err != nil {
t.Fatalf("Failed to create random posts: %v", err)
t.Fatalf("Failed to get posts: %v", err)
}
if len(posts) != 5 {
@@ -84,39 +98,49 @@ func TestSeedCommand(t *testing.T) {
if post.AuthorID == nil || *post.AuthorID != seedUser.ID {
t.Errorf("Post %d has wrong author ID: expected %d, got %v", i, seedUser.ID, post.AuthorID)
}
expectedScore := post.UpVotes - post.DownVotes
if post.Score != expectedScore {
t.Errorf("Post %d has incorrect score: expected %d, got %d", i, expectedScore, post.Score)
}
}
allUsers := append([]database.User{*seedUser}, users...)
votes, err := createRandomVotes(voteRepo, allUsers, posts, 3)
voteCount, err := voteRepo.Count()
if err != nil {
t.Fatalf("Failed to create random votes: %v", err)
t.Fatalf("Failed to count votes: %v", err)
}
if votes == 0 {
if voteCount == 0 {
t.Error("Expected some votes to be created")
}
err = updatePostScores(postRepo, voteRepo, posts)
if err != nil {
t.Fatalf("Failed to update post scores: %v", err)
}
for i, post := range posts {
updatedPost, err := postRepo.GetByID(post.ID)
for _, post := range posts {
postVotes, err := voteRepo.GetByPostID(post.ID)
if err != nil {
t.Errorf("Failed to get updated post %d: %v", i, err)
t.Errorf("Failed to get votes for post %d: %v", post.ID, err)
continue
}
expectedScore := updatedPost.UpVotes - updatedPost.DownVotes
if updatedPost.Score != expectedScore {
t.Errorf("Post %d has incorrect score: expected %d, got %d", i, expectedScore, updatedPost.Score)
for _, vote := range postVotes {
if vote.PostID != post.ID {
t.Errorf("Vote has wrong post ID: expected %d, got %d", post.ID, vote.PostID)
}
if vote.UserID == nil {
t.Error("Vote has nil user ID")
}
}
}
}
func TestGenerateRandomPath(t *testing.T) {
path := generateRandomPath()
initSeedRand()
pathLength := seedRandSource.Intn(20)
path := "/article/"
for i := 0; i < pathLength+5; i++ {
randomChar := seedRandSource.Intn(26)
path += string(rune('a' + randomChar))
}
if path == "" {
t.Error("Generated path should not be empty")
@@ -126,7 +150,14 @@ func TestGenerateRandomPath(t *testing.T) {
t.Errorf("Generated path too short: %s", path)
}
secondPath := generateRandomPath()
initSeedRand()
secondPathLength := seedRandSource.Intn(20)
secondPath := "/article/"
for i := 0; i < secondPathLength+5; i++ {
randomChar := seedRandSource.Intn(26)
secondPath += string(rune('a' + randomChar))
}
if path == secondPath {
t.Error("Generated paths should be different")
}
@@ -271,13 +302,13 @@ func TestSeedCommandIdempotency(t *testing.T) {
seedUserCount := 0
for _, user := range users {
if strings.HasPrefix(user.Username, "seed_admin_") {
if user.Username == "seed_admin" {
seedUserCount++
}
}
if seedUserCount < 1 {
t.Errorf("Expected at least 1 seed user, got %d", seedUserCount)
if seedUserCount != 1 {
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
}
})
@@ -356,7 +387,7 @@ func TestSeedCommandIdempotency(t *testing.T) {
func findSeedUser(users []database.User) *database.User {
for i := range users {
if strings.HasPrefix(users[i].Username, "seed_admin_") {
if users[i].Username == "seed_admin" {
return &users[i]
}
}
@@ -445,3 +476,58 @@ func TestSeedCommandTransactionRollback(t *testing.T) {
}
})
}
func TestEnsureSeedUser(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
if err := db.AutoMigrate(&database.User{}); err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
passwordHash := "test_password_hash"
firstUser, err := ensureSeedUser(userRepo, passwordHash)
if err != nil {
t.Fatalf("Failed to create seed user: %v", err)
}
if firstUser.Username != "seed_admin" || firstUser.Email != "seed_admin@goyco.local" || firstUser.Password != passwordHash || !firstUser.EmailVerified {
t.Errorf("Invalid seed user: username=%s, email=%s, password matches=%v, emailVerified=%v",
firstUser.Username, firstUser.Email, firstUser.Password == passwordHash, firstUser.EmailVerified)
}
secondUser, err := ensureSeedUser(userRepo, "different_password_hash")
if err != nil {
t.Fatalf("Failed to reuse seed user: %v", err)
}
if firstUser.ID != secondUser.ID {
t.Errorf("Expected same user to be reused (ID %d), got different user (ID %d)", firstUser.ID, secondUser.ID)
}
for i := 0; i < 3; i++ {
if _, err := ensureSeedUser(userRepo, passwordHash); err != nil {
t.Fatalf("Call %d failed: %v", i+1, err)
}
}
users, err := userRepo.GetAll(100, 0)
if err != nil {
t.Fatalf("Failed to get users: %v", err)
}
seedUserCount := 0
for _, user := range users {
if user.Username == "seed_admin" {
seedUserCount++
}
}
if seedUserCount != 1 {
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
}
}

View File

@@ -96,21 +96,21 @@ func TestServerConfigurationFromConfig(t *testing.T) {
testServer := httptest.NewServer(srv.Handler)
defer testServer.Close()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
resp, err := http.DefaultClient.Do(req)
response, err := http.DefaultClient.Do(request)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", response.StatusCode)
}
}
@@ -208,27 +208,27 @@ func TestTLSWiringFromConfig(t *testing.T) {
},
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
resp, err := client.Do(req)
response, err := client.Do(request)
if err != nil {
t.Fatalf("Failed to make TLS request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 over TLS, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 over TLS, got %d", response.StatusCode)
}
if resp.TLS == nil {
if response.TLS == nil {
t.Error("Expected TLS connection info to be present in response")
} else if resp.TLS.Version < tls.VersionTLS12 {
t.Errorf("Expected TLS version 1.2 or higher, got %x", resp.TLS.Version)
} else if response.TLS.Version < tls.VersionTLS12 {
t.Errorf("Expected TLS version 1.2 or higher, got %x", response.TLS.Version)
}
}
}
@@ -368,38 +368,38 @@ func TestServerInitializationFlow(t *testing.T) {
testServer := httptest.NewServer(srv.Handler)
defer testServer.Close()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
resp, err := http.DefaultClient.Do(req)
response, err := http.DefaultClient.Do(request)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", response.StatusCode)
}
req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/api", nil)
request, err = http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/api", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
resp, err = http.DefaultClient.Do(req)
response, err = http.DefaultClient.Do(request)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for API endpoint, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for API endpoint, got %d", response.StatusCode)
}
}

View File

@@ -111,7 +111,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ConfirmAccountDeletionRequest"
"$ref": "#/definitions/dto.ConfirmAccountDeletionRequest"
}
}
],
@@ -212,7 +212,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateEmailRequest"
"$ref": "#/definitions/dto.UpdateEmailRequest"
}
}
],
@@ -276,7 +276,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ForgotPasswordRequest"
"$ref": "#/definitions/dto.ForgotPasswordRequest"
}
}
],
@@ -316,7 +316,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.LoginRequest"
"$ref": "#/definitions/dto.LoginRequest"
}
}
],
@@ -453,7 +453,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePasswordRequest"
"$ref": "#/definitions/dto.UpdatePasswordRequest"
}
}
],
@@ -505,7 +505,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RefreshTokenRequest"
"$ref": "#/definitions/dto.RefreshTokenRequest"
}
}
],
@@ -563,7 +563,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -615,7 +615,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResendVerificationRequest"
"$ref": "#/definitions/dto.ResendVerificationRequest"
}
}
],
@@ -685,7 +685,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResetPasswordRequest"
"$ref": "#/definitions/dto.ResetPasswordRequest"
}
}
],
@@ -736,7 +736,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RevokeTokenRequest"
"$ref": "#/definitions/dto.RevokeTokenRequest"
}
}
],
@@ -833,7 +833,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateUsernameRequest"
"$ref": "#/definitions/dto.UpdateUsernameRequest"
}
}
],
@@ -945,7 +945,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.CreatePostRequest"
"$ref": "#/definitions/dto.CreatePostRequest"
}
}
],
@@ -1176,7 +1176,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePostRequest"
"$ref": "#/definitions/dto.UpdatePostRequest"
}
}
],
@@ -1370,7 +1370,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.VoteRequest"
"$ref": "#/definitions/dto.CastVoteRequest"
}
}
],
@@ -1601,7 +1601,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -1817,6 +1817,223 @@ const docTemplate = `{
}
},
"definitions": {
"dto.CastVoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"dto.ConfirmAccountDeletionRequest": {
"type": "object",
"required": [
"token"
],
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"dto.CreatePostRequest": {
"type": "object",
"required": [
"url"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
},
"url": {
"type": "string",
"maxLength": 2048
}
}
},
"dto.ForgotPasswordRequest": {
"type": "object",
"required": [
"username_or_email"
],
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"dto.LoginRequest": {
"type": "object",
"required": [
"password",
"username"
],
"properties": {
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.RegisterRequest": {
"type": "object",
"required": [
"email",
"password",
"username"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
},
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.ResendVerificationRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.ResetPasswordRequest": {
"type": "object",
"required": [
"new_password",
"token"
],
"properties": {
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"token": {
"type": "string"
}
}
},
"dto.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.UpdateEmailRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.UpdatePasswordRequest": {
"type": "object",
"required": [
"current_password",
"new_password"
],
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
}
}
},
"dto.UpdatePostRequest": {
"type": "object",
"required": [
"title"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
}
}
},
"dto.UpdateUsernameRequest": {
"type": "object",
"required": [
"username"
],
"properties": {
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"handlers.APIInfo": {
"type": "object",
"properties": {
@@ -1919,50 +2136,6 @@ const docTemplate = `{
}
}
},
"handlers.ConfirmAccountDeletionRequest": {
"type": "object",
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"handlers.CreatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
},
"url": {
"type": "string"
}
}
},
"handlers.ForgotPasswordRequest": {
"type": "object",
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"handlers.LoginRequest": {
"type": "object",
"properties": {
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.PostResponse": {
"type": "object",
"properties": {
@@ -1978,101 +2151,6 @@ const docTemplate = `{
}
}
},
"handlers.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.RegisterRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
},
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.ResendVerificationRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.ResetPasswordRequest": {
"type": "object",
"properties": {
"new_password": {
"type": "string"
},
"token": {
"type": "string"
}
}
},
"handlers.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.UpdateEmailRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.UpdatePasswordRequest": {
"type": "object",
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string"
}
}
},
"handlers.UpdatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
}
}
},
"handlers.UpdateUsernameRequest": {
"type": "object",
"properties": {
"username": {
"type": "string"
}
}
},
"handlers.UserResponse": {
"type": "object",
"properties": {
@@ -2088,21 +2166,6 @@ const docTemplate = `{
}
}
},
"handlers.VoteRequest": {
"description": "Vote request with type field. All votes are handled the same way.",
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
],
"example": "up"
}
}
},
"handlers.VoteResponse": {
"type": "object",
"properties": {

View File

@@ -108,7 +108,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ConfirmAccountDeletionRequest"
"$ref": "#/definitions/dto.ConfirmAccountDeletionRequest"
}
}
],
@@ -209,7 +209,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateEmailRequest"
"$ref": "#/definitions/dto.UpdateEmailRequest"
}
}
],
@@ -273,7 +273,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ForgotPasswordRequest"
"$ref": "#/definitions/dto.ForgotPasswordRequest"
}
}
],
@@ -313,7 +313,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.LoginRequest"
"$ref": "#/definitions/dto.LoginRequest"
}
}
],
@@ -450,7 +450,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePasswordRequest"
"$ref": "#/definitions/dto.UpdatePasswordRequest"
}
}
],
@@ -502,7 +502,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RefreshTokenRequest"
"$ref": "#/definitions/dto.RefreshTokenRequest"
}
}
],
@@ -560,7 +560,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -612,7 +612,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResendVerificationRequest"
"$ref": "#/definitions/dto.ResendVerificationRequest"
}
}
],
@@ -682,7 +682,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResetPasswordRequest"
"$ref": "#/definitions/dto.ResetPasswordRequest"
}
}
],
@@ -733,7 +733,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RevokeTokenRequest"
"$ref": "#/definitions/dto.RevokeTokenRequest"
}
}
],
@@ -830,7 +830,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateUsernameRequest"
"$ref": "#/definitions/dto.UpdateUsernameRequest"
}
}
],
@@ -942,7 +942,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.CreatePostRequest"
"$ref": "#/definitions/dto.CreatePostRequest"
}
}
],
@@ -1173,7 +1173,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePostRequest"
"$ref": "#/definitions/dto.UpdatePostRequest"
}
}
],
@@ -1367,7 +1367,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.VoteRequest"
"$ref": "#/definitions/dto.CastVoteRequest"
}
}
],
@@ -1598,7 +1598,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -1814,6 +1814,223 @@
}
},
"definitions": {
"dto.CastVoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"dto.ConfirmAccountDeletionRequest": {
"type": "object",
"required": [
"token"
],
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"dto.CreatePostRequest": {
"type": "object",
"required": [
"url"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
},
"url": {
"type": "string",
"maxLength": 2048
}
}
},
"dto.ForgotPasswordRequest": {
"type": "object",
"required": [
"username_or_email"
],
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"dto.LoginRequest": {
"type": "object",
"required": [
"password",
"username"
],
"properties": {
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.RegisterRequest": {
"type": "object",
"required": [
"email",
"password",
"username"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
},
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.ResendVerificationRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.ResetPasswordRequest": {
"type": "object",
"required": [
"new_password",
"token"
],
"properties": {
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"token": {
"type": "string"
}
}
},
"dto.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.UpdateEmailRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.UpdatePasswordRequest": {
"type": "object",
"required": [
"current_password",
"new_password"
],
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
}
}
},
"dto.UpdatePostRequest": {
"type": "object",
"required": [
"title"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
}
}
},
"dto.UpdateUsernameRequest": {
"type": "object",
"required": [
"username"
],
"properties": {
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"handlers.APIInfo": {
"type": "object",
"properties": {
@@ -1916,50 +2133,6 @@
}
}
},
"handlers.ConfirmAccountDeletionRequest": {
"type": "object",
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"handlers.CreatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
},
"url": {
"type": "string"
}
}
},
"handlers.ForgotPasswordRequest": {
"type": "object",
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"handlers.LoginRequest": {
"type": "object",
"properties": {
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.PostResponse": {
"type": "object",
"properties": {
@@ -1975,101 +2148,6 @@
}
}
},
"handlers.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.RegisterRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
},
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.ResendVerificationRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.ResetPasswordRequest": {
"type": "object",
"properties": {
"new_password": {
"type": "string"
},
"token": {
"type": "string"
}
}
},
"handlers.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.UpdateEmailRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.UpdatePasswordRequest": {
"type": "object",
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string"
}
}
},
"handlers.UpdatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
}
}
},
"handlers.UpdateUsernameRequest": {
"type": "object",
"properties": {
"username": {
"type": "string"
}
}
},
"handlers.UserResponse": {
"type": "object",
"properties": {
@@ -2085,21 +2163,6 @@
}
}
},
"handlers.VoteRequest": {
"description": "Vote request with type field. All votes are handled the same way.",
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
],
"example": "up"
}
}
},
"handlers.VoteResponse": {
"type": "object",
"properties": {

View File

@@ -1,5 +1,156 @@
basePath: /api
definitions:
dto.CastVoteRequest:
properties:
type:
enum:
- up
- down
- none
type: string
required:
- type
type: object
dto.ConfirmAccountDeletionRequest:
properties:
delete_posts:
type: boolean
token:
type: string
required:
- token
type: object
dto.CreatePostRequest:
properties:
content:
maxLength: 10000
type: string
title:
maxLength: 200
minLength: 3
type: string
url:
maxLength: 2048
type: string
required:
- url
type: object
dto.ForgotPasswordRequest:
properties:
username_or_email:
type: string
required:
- username_or_email
type: object
dto.LoginRequest:
properties:
password:
maxLength: 128
minLength: 8
type: string
username:
maxLength: 50
minLength: 3
type: string
required:
- password
- username
type: object
dto.RefreshTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
dto.RegisterRequest:
properties:
email:
maxLength: 254
type: string
password:
maxLength: 128
minLength: 8
type: string
username:
maxLength: 50
minLength: 3
type: string
required:
- email
- password
- username
type: object
dto.ResendVerificationRequest:
properties:
email:
maxLength: 254
type: string
required:
- email
type: object
dto.ResetPasswordRequest:
properties:
new_password:
maxLength: 128
minLength: 8
type: string
token:
type: string
required:
- new_password
- token
type: object
dto.RevokeTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
dto.UpdateEmailRequest:
properties:
email:
maxLength: 254
type: string
required:
- email
type: object
dto.UpdatePasswordRequest:
properties:
current_password:
type: string
new_password:
maxLength: 128
minLength: 8
type: string
required:
- current_password
- new_password
type: object
dto.UpdatePostRequest:
properties:
content:
maxLength: 10000
type: string
title:
maxLength: 200
minLength: 3
type: string
required:
- title
type: object
dto.UpdateUsernameRequest:
properties:
username:
maxLength: 50
minLength: 3
type: string
required:
- username
type: object
handlers.APIInfo:
properties:
data: {}
@@ -70,34 +221,6 @@ definitions:
success:
type: boolean
type: object
handlers.ConfirmAccountDeletionRequest:
properties:
delete_posts:
type: boolean
token:
type: string
type: object
handlers.CreatePostRequest:
properties:
content:
type: string
title:
type: string
url:
type: string
type: object
handlers.ForgotPasswordRequest:
properties:
username_or_email:
type: string
type: object
handlers.LoginRequest:
properties:
password:
type: string
username:
type: string
type: object
handlers.PostResponse:
properties:
data: {}
@@ -108,67 +231,6 @@ definitions:
success:
type: boolean
type: object
handlers.RefreshTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
handlers.RegisterRequest:
properties:
email:
type: string
password:
type: string
username:
type: string
type: object
handlers.ResendVerificationRequest:
properties:
email:
type: string
type: object
handlers.ResetPasswordRequest:
properties:
new_password:
type: string
token:
type: string
type: object
handlers.RevokeTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
handlers.UpdateEmailRequest:
properties:
email:
type: string
type: object
handlers.UpdatePasswordRequest:
properties:
current_password:
type: string
new_password:
type: string
type: object
handlers.UpdatePostRequest:
properties:
content:
type: string
title:
type: string
type: object
handlers.UpdateUsernameRequest:
properties:
username:
type: string
type: object
handlers.UserResponse:
properties:
data: {}
@@ -179,17 +241,6 @@ definitions:
success:
type: boolean
type: object
handlers.VoteRequest:
description: Vote request with type field. All votes are handled the same way.
properties:
type:
enum:
- up
- down
- none
example: up
type: string
type: object
handlers.VoteResponse:
properties:
data: {}
@@ -268,7 +319,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ConfirmAccountDeletionRequest'
$ref: '#/definitions/dto.ConfirmAccountDeletionRequest'
produces:
- application/json
responses:
@@ -331,7 +382,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdateEmailRequest'
$ref: '#/definitions/dto.UpdateEmailRequest'
produces:
- application/json
responses:
@@ -375,7 +426,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ForgotPasswordRequest'
$ref: '#/definitions/dto.ForgotPasswordRequest'
produces:
- application/json
responses:
@@ -401,7 +452,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.LoginRequest'
$ref: '#/definitions/dto.LoginRequest'
produces:
- application/json
responses:
@@ -485,7 +536,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdatePasswordRequest'
$ref: '#/definitions/dto.UpdatePasswordRequest'
produces:
- application/json
responses:
@@ -523,7 +574,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RefreshTokenRequest'
$ref: '#/definitions/dto.RefreshTokenRequest'
produces:
- application/json
responses:
@@ -561,7 +612,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RegisterRequest'
$ref: '#/definitions/dto.RegisterRequest'
produces:
- application/json
responses:
@@ -595,7 +646,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ResendVerificationRequest'
$ref: '#/definitions/dto.ResendVerificationRequest'
produces:
- application/json
responses:
@@ -641,7 +692,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ResetPasswordRequest'
$ref: '#/definitions/dto.ResetPasswordRequest'
produces:
- application/json
responses:
@@ -672,7 +723,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RevokeTokenRequest'
$ref: '#/definitions/dto.RevokeTokenRequest'
produces:
- application/json
responses:
@@ -735,7 +786,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdateUsernameRequest'
$ref: '#/definitions/dto.UpdateUsernameRequest'
produces:
- application/json
responses:
@@ -809,7 +860,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.CreatePostRequest'
$ref: '#/definitions/dto.CreatePostRequest'
produces:
- application/json
responses:
@@ -933,7 +984,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdatePostRequest'
$ref: '#/definitions/dto.UpdatePostRequest'
produces:
- application/json
responses:
@@ -1070,7 +1121,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.VoteRequest'
$ref: '#/definitions/dto.CastVoteRequest'
produces:
- application/json
responses:
@@ -1260,7 +1311,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RegisterRequest'
$ref: '#/definitions/dto.RegisterRequest'
produces:
- application/json
responses:

View File

@@ -0,0 +1,51 @@
package dto
type LoginRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
Password string `json:"password" validate:"required,min=8,max=128"`
}
type RegisterRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
Email string `json:"email" validate:"required,email,max=254"`
Password string `json:"password" validate:"required,min=8,max=128"`
}
type ResendVerificationRequest struct {
Email string `json:"email" validate:"required,email,max=254"`
}
type ForgotPasswordRequest struct {
UsernameOrEmail string `json:"username_or_email" validate:"required"`
}
type ResetPasswordRequest struct {
Token string `json:"token" validate:"required"`
NewPassword string `json:"new_password" validate:"required,min=8,max=128"`
}
type UpdateEmailRequest struct {
Email string `json:"email" validate:"required,email,max=254"`
}
type UpdateUsernameRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
}
type UpdatePasswordRequest struct {
CurrentPassword string `json:"current_password" validate:"required"`
NewPassword string `json:"new_password" validate:"required,min=8,max=128"`
}
type ConfirmAccountDeletionRequest struct {
Token string `json:"token" validate:"required"`
DeletePosts bool `json:"delete_posts"`
}
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." validate:"required"`
}
type RevokeTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." validate:"required"`
}

View File

@@ -0,0 +1,12 @@
package dto
type CreatePostRequest struct {
Title string `json:"title" validate:"omitempty,min=3,max=200"`
URL string `json:"url" validate:"required,url,max=2048"`
Content string `json:"content" validate:"omitempty,max=10000"`
}
type UpdatePostRequest struct {
Title string `json:"title" validate:"required,min=3,max=200"`
Content string `json:"content" validate:"omitempty,max=10000"`
}

View File

@@ -6,6 +6,10 @@ import (
"goyco/internal/database"
)
type CastVoteRequest struct {
Type string `json:"type" validate:"required,oneof=up down none"`
}
type VoteDTO struct {
ID uint `json:"id"`
UserID *uint `json:"user_id,omitempty"`

View File

@@ -112,37 +112,37 @@ func newInMemoryRoundTripper(handler http.Handler) http.RoundTripper {
return &inMemoryRoundTripper{handler: handler}
}
func (rt *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func (rt *inMemoryRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
if rt == nil || rt.handler == nil {
return nil, fmt.Errorf("in-memory round tripper not initialized")
}
var bodyBytes []byte
if req.Body != nil && req.Body != http.NoBody {
defer req.Body.Close()
if request.Body != nil && request.Body != http.NoBody {
defer request.Body.Close()
var err error
bodyBytes, err = io.ReadAll(req.Body)
bodyBytes, err = io.ReadAll(request.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}
}
if len(bodyBytes) > 0 {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
} else {
req.Body = http.NoBody
request.Body = http.NoBody
}
clonedReq := req.Clone(req.Context())
clonedRequest := request.Clone(request.Context())
if len(bodyBytes) > 0 {
clonedReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
clonedRequest.Body = io.NopCloser(bytes.NewReader(bodyBytes))
} else {
clonedReq.Body = http.NoBody
clonedRequest.Body = http.NoBody
}
clonedReq.RequestURI = clonedReq.URL.RequestURI()
clonedRequest.RequestURI = clonedRequest.URL.RequestURI()
recorder := httptest.NewRecorder()
rt.handler.ServeHTTP(recorder, clonedReq)
rt.handler.ServeHTTP(recorder, clonedRequest)
resp := recorder.Result()
return resp, nil
}
@@ -250,7 +250,7 @@ func tokenHash(token string) string {
func retryOnRateLimit(t *testing.T, maxRetries int, operation func() int) int {
t.Helper()
for attempt := 0; attempt < maxRetries; attempt++ {
for attempt := range maxRetries {
statusCode := operation()
if statusCode != http.StatusTooManyRequests {
return statusCode

View File

@@ -15,22 +15,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
ctx := setupTestContext(t)
t.Run("compression_enabled_with_accept_encoding", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Accept-Encoding", "gzip")
testutils.WithStandardHeaders(req)
request.Header.Set("Accept-Encoding", "gzip")
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
contentEncoding := resp.Header.Get("Content-Encoding")
contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" {
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(response.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
@@ -57,19 +57,19 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
})
t.Run("no_compression_without_accept_encoding", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req)
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
contentEncoding := resp.Header.Get("Content-Encoding")
contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" {
t.Error("Expected no compression without Accept-Encoding header")
}
@@ -85,22 +85,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
gz.Write([]byte(postData))
gz.Close()
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf)
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Encoding", "gzip")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Content-Encoding", "gzip")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
switch resp.StatusCode {
switch response.StatusCode {
case http.StatusBadRequest:
t.Log("Decompression middleware rejected invalid gzip")
case http.StatusCreated, http.StatusOK:
@@ -113,37 +113,37 @@ func TestE2E_CacheMiddleware(t *testing.T) {
ctx := setupTestContext(t)
t.Run("cache_miss_then_hit", func(t *testing.T) {
req1, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req1)
testutils.WithStandardHeaders(firstRequest)
resp1, err := ctx.client.Do(req1)
firstResponse, err := ctx.client.Do(firstRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
resp1.Body.Close()
firstResponse.Body.Close()
cacheStatus1 := resp1.Header.Get("X-Cache")
if cacheStatus1 == "HIT" {
firstCacheStatus := firstResponse.Header.Get("X-Cache")
if firstCacheStatus == "HIT" {
t.Log("First request was cached (unexpected but acceptable)")
}
req2, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
secondRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req2)
testutils.WithStandardHeaders(secondRequest)
resp2, err := ctx.client.Do(req2)
secondResponse, err := ctx.client.Do(secondRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp2.Body.Close()
defer secondResponse.Body.Close()
cacheStatus2 := resp2.Header.Get("X-Cache")
if cacheStatus2 == "HIT" {
secondCacheStatus := secondResponse.Header.Get("X-Cache")
if secondCacheStatus == "HIT" {
t.Log("Second request was served from cache")
}
})
@@ -152,48 +152,48 @@ func TestE2E_CacheMiddleware(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "cacheuser", "StrongPass123!")
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
req1, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req1)
req1.Header.Set("Authorization", "Bearer "+authClient.Token)
testutils.WithStandardHeaders(firstRequest)
firstRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp1, err := ctx.client.Do(req1)
firstResponse, err := ctx.client.Do(firstRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
resp1.Body.Close()
firstResponse.Body.Close()
postData := `{"title":"Cache Invalidation Test","url":"https://example.com/cache","content":"Test"}`
req2, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
secondRequest, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req2.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req2)
req2.Header.Set("Authorization", "Bearer "+authClient.Token)
secondRequest.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(secondRequest)
secondRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp2, err := ctx.client.Do(req2)
secondResponse, err := ctx.client.Do(secondRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
resp2.Body.Close()
secondResponse.Body.Close()
req3, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
thirdRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req3)
req3.Header.Set("Authorization", "Bearer "+authClient.Token)
testutils.WithStandardHeaders(thirdRequest)
thirdRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp3, err := ctx.client.Do(req3)
thirdResponse, err := ctx.client.Do(thirdRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp3.Body.Close()
defer thirdResponse.Body.Close()
cacheStatus := resp3.Header.Get("X-Cache")
cacheStatus := thirdResponse.Header.Get("X-Cache")
if cacheStatus == "HIT" {
t.Log("Cache was invalidated after POST")
}
@@ -204,23 +204,23 @@ func TestE2E_CSRFProtection(t *testing.T) {
ctx := setupTestContext(t)
t.Run("csrf_protection_for_non_api_routes", func(t *testing.T) {
req, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`))
request, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden {
if response.StatusCode == http.StatusForbidden {
t.Log("CSRF protection active for non-API routes")
} else {
t.Logf("CSRF check result: status %d", resp.StatusCode)
t.Logf("CSRF check result: status %d", response.StatusCode)
}
})
@@ -229,39 +229,39 @@ func TestE2E_CSRFProtection(t *testing.T) {
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
postData := `{"title":"CSRF Test","url":"https://example.com/csrf","content":"Test"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden {
if response.StatusCode == http.StatusForbidden {
t.Error("API routes should bypass CSRF protection")
}
})
t.Run("csrf_allows_get_requests", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil)
request, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req)
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden {
if response.StatusCode == http.StatusForbidden {
t.Error("GET requests should not require CSRF token")
}
})
@@ -276,21 +276,21 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
smallData := strings.Repeat("a", 100)
postData := `{"title":"` + smallData + `","url":"https://example.com","content":"test"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusRequestEntityTooLarge {
if response.StatusCode == http.StatusRequestEntityTooLarge {
t.Error("Small request should not exceed size limit")
}
})
@@ -301,24 +301,24 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
largeData := strings.Repeat("a", 2*1024*1024)
postData := `{"title":"test","url":"https://example.com","content":"` + largeData + `"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
return
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusRequestEntityTooLarge {
if response.StatusCode == http.StatusRequestEntityTooLarge {
t.Log("Request size limit enforced correctly")
} else {
t.Logf("Request size limit check result: status %d", resp.StatusCode)
t.Logf("Request size limit check result: status %d", response.StatusCode)
}
})
}

View File

@@ -64,62 +64,6 @@ type AuthUserSummary struct {
Locked bool `json:"locked" example:"false"`
}
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type RegisterRequest struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
type CreatePostRequest struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
type ResendVerificationRequest struct {
Email string `json:"email"`
}
type ForgotPasswordRequest struct {
UsernameOrEmail string `json:"username_or_email"`
}
type ResetPasswordRequest struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
type UpdateEmailRequest struct {
Email string `json:"email"`
}
type UpdateUsernameRequest struct {
Username string `json:"username"`
}
type UpdatePasswordRequest struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
type ConfirmAccountDeletionRequest struct {
Token string `json:"token"`
DeletePosts bool `json:"delete_posts"`
}
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
}
type RevokeTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
}
func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.UserRepository) *AuthHandler {
return &AuthHandler{
authService: authService,
@@ -132,7 +76,7 @@ func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.User
// @Tags auth
// @Accept json
// @Produce json
// @Param request body LoginRequest true "Login credentials"
// @Param request body dto.LoginRequest true "Login credentials"
// @Success 200 {object} AuthTokensResponse "Authentication successful"
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
// @Failure 401 {object} AuthResponse "Invalid credentials"
@@ -140,23 +84,15 @@ func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.User
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/login [post]
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.LoginRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
username := security.SanitizeUsername(req.Username)
password := strings.TrimSpace(req.Password)
if username == "" || password == "" {
SendErrorResponse(w, "Username and password are required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(password); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
@@ -175,20 +111,16 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body RegisterRequest true "Registration data"
// @Param request body dto.RegisterRequest true "Registration data"
// @Success 201 {object} AuthResponse "Registration successful"
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
// @Failure 409 {object} AuthResponse "Username or email already exists"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/register [post]
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RegisterRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -196,11 +128,6 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
email := strings.TrimSpace(req.Email)
password := strings.TrimSpace(req.Password)
if username == "" || email == "" || password == "" {
SendErrorResponse(w, "Username, email, and password are required", http.StatusBadRequest)
return
}
username = security.SanitizeUsername(username)
if err := validation.ValidateUsername(username); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
@@ -280,7 +207,7 @@ func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ResendVerificationRequest true "Email address"
// @Param request body dto.ResendVerificationRequest true "Email address"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 404 {object} AuthResponse
@@ -290,15 +217,14 @@ func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse
// @Router /api/auth/resend-verification [post]
func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ResendVerificationRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
email := strings.TrimSpace(req.Email)
if email == "" {
SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
return
@@ -359,20 +285,19 @@ func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ForgotPasswordRequest true "Username or email"
// @Param request body dto.ForgotPasswordRequest true "Username or email"
// @Success 200 {object} AuthResponse "Password reset email sent if account exists"
// @Failure 400 {object} AuthResponse "Invalid request data"
// @Router /api/auth/forgot-password [post]
func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
var req struct {
UsernameOrEmail string `json:"username_or_email"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ForgotPasswordRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail)
if usernameOrEmail == "" {
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
return
@@ -389,18 +314,15 @@ func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Reques
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ResetPasswordRequest true "Password reset data"
// @Param request body dto.ResetPasswordRequest true "Password reset data"
// @Success 200 {object} AuthResponse "Password reset successfully"
// @Failure 400 {object} AuthResponse "Invalid or expired token, or validation failed"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/reset-password [post]
func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
var req struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ResetPasswordRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -408,17 +330,12 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
newPassword := strings.TrimSpace(req.NewPassword)
if token == "" {
SendErrorResponse(w, "Reset token is required", http.StatusBadRequest)
SendErrorResponse(w, "Token is required", http.StatusBadRequest)
return
}
if newPassword == "" {
SendErrorResponse(w, "New password is required", http.StatusBadRequest)
return
}
if len(newPassword) < 8 {
SendErrorResponse(w, "Password must be at least 8 characters long", http.StatusBadRequest)
if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
@@ -443,7 +360,7 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdateEmailRequest true "New email address"
// @Param request body dto.UpdateEmailRequest true "New email address"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
@@ -457,11 +374,9 @@ func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
Email string `json:"email"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdateEmailRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -498,7 +413,7 @@ func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdateUsernameRequest true "New username"
// @Param request body dto.UpdateUsernameRequest true "New username"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
@@ -511,11 +426,9 @@ func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
Username string `json:"username"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdateUsernameRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -548,7 +461,7 @@ func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdatePasswordRequest true "Password update data"
// @Param request body dto.UpdatePasswordRequest true "Password update data"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
@@ -560,12 +473,9 @@ func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdatePasswordRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -633,23 +543,21 @@ func (h *AuthHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ConfirmAccountDeletionRequest true "Account deletion data"
// @Param request body dto.ConfirmAccountDeletionRequest true "Account deletion data"
// @Success 200 {object} AuthResponse "Account deleted successfully"
// @Failure 400 {object} AuthResponse "Invalid or expired token"
// @Failure 503 {object} AuthResponse "Email delivery unavailable"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/account/confirm [post]
func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
var req struct {
Token string `json:"token"`
DeletePosts bool `json:"delete_posts"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ConfirmAccountDeletionRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
token := strings.TrimSpace(req.Token)
if token == "" {
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
return
@@ -694,7 +602,7 @@ func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body RefreshTokenRequest true "Refresh token data"
// @Param request body dto.RefreshTokenRequest true "Refresh token data"
// @Success 200 {object} AuthTokensResponse "Token refreshed successfully"
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
// @Failure 401 {object} AuthResponse "Invalid or expired refresh token"
@@ -702,13 +610,13 @@ func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/refresh [post]
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
var req RefreshTokenRequest
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RefreshTokenRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.RefreshToken) == "" {
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
@@ -727,20 +635,20 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body RevokeTokenRequest true "Token revocation data"
// @Param request body dto.RevokeTokenRequest true "Token revocation data"
// @Success 200 {object} AuthResponse "Token revoked successfully"
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
// @Failure 401 {object} AuthResponse "Invalid or expired access token"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/revoke [post]
func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
var req RevokeTokenRequest
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RevokeTokenRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.RefreshToken) == "" {
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
@@ -782,28 +690,28 @@ func (h *AuthHandler) RevokeAllTokens(w http.ResponseWriter, r *http.Request) {
func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
if config.GeneralRateLimit != nil {
rateLimited := config.GeneralRateLimit(r)
rateLimited.Post("/auth/refresh", h.RefreshToken)
rateLimited.Post("/auth/refresh", WithValidation[dto.RefreshTokenRequest](config.ValidationMiddleware, h.RefreshToken))
rateLimited.Get("/auth/confirm", h.ConfirmEmail)
rateLimited.Post("/auth/resend-verification", h.ResendVerificationEmail)
rateLimited.Post("/auth/resend-verification", WithValidation[dto.ResendVerificationRequest](config.ValidationMiddleware, h.ResendVerificationEmail))
} else {
r.Post("/auth/refresh", h.RefreshToken)
r.Post("/auth/refresh", WithValidation[dto.RefreshTokenRequest](config.ValidationMiddleware, h.RefreshToken))
r.Get("/auth/confirm", h.ConfirmEmail)
r.Post("/auth/resend-verification", h.ResendVerificationEmail)
r.Post("/auth/resend-verification", WithValidation[dto.ResendVerificationRequest](config.ValidationMiddleware, h.ResendVerificationEmail))
}
if config.AuthRateLimit != nil {
rateLimited := config.AuthRateLimit(r)
rateLimited.Post("/auth/register", h.Register)
rateLimited.Post("/auth/login", h.Login)
rateLimited.Post("/auth/forgot-password", h.RequestPasswordReset)
rateLimited.Post("/auth/reset-password", h.ResetPassword)
rateLimited.Post("/auth/account/confirm", h.ConfirmAccountDeletion)
rateLimited.Post("/auth/register", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.Register))
rateLimited.Post("/auth/login", WithValidation[dto.LoginRequest](config.ValidationMiddleware, h.Login))
rateLimited.Post("/auth/forgot-password", WithValidation[dto.ForgotPasswordRequest](config.ValidationMiddleware, h.RequestPasswordReset))
rateLimited.Post("/auth/reset-password", WithValidation[dto.ResetPasswordRequest](config.ValidationMiddleware, h.ResetPassword))
rateLimited.Post("/auth/account/confirm", WithValidation[dto.ConfirmAccountDeletionRequest](config.ValidationMiddleware, h.ConfirmAccountDeletion))
} else {
r.Post("/auth/register", h.Register)
r.Post("/auth/login", h.Login)
r.Post("/auth/forgot-password", h.RequestPasswordReset)
r.Post("/auth/reset-password", h.ResetPassword)
r.Post("/auth/account/confirm", h.ConfirmAccountDeletion)
r.Post("/auth/register", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.Register))
r.Post("/auth/login", WithValidation[dto.LoginRequest](config.ValidationMiddleware, h.Login))
r.Post("/auth/forgot-password", WithValidation[dto.ForgotPasswordRequest](config.ValidationMiddleware, h.RequestPasswordReset))
r.Post("/auth/reset-password", WithValidation[dto.ResetPasswordRequest](config.ValidationMiddleware, h.ResetPassword))
r.Post("/auth/account/confirm", WithValidation[dto.ConfirmAccountDeletionRequest](config.ValidationMiddleware, h.ConfirmAccountDeletion))
}
protected := r
@@ -816,10 +724,10 @@ func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
protected.Get("/auth/me", h.Me)
protected.Post("/auth/logout", h.Logout)
protected.Post("/auth/revoke", h.RevokeToken)
protected.Post("/auth/revoke", WithValidation[dto.RevokeTokenRequest](config.ValidationMiddleware, h.RevokeToken))
protected.Post("/auth/revoke-all", h.RevokeAllTokens)
protected.Put("/auth/email", h.UpdateEmail)
protected.Put("/auth/username", h.UpdateUsername)
protected.Put("/auth/password", h.UpdatePassword)
protected.Put("/auth/email", WithValidation[dto.UpdateEmailRequest](config.ValidationMiddleware, h.UpdateEmail))
protected.Put("/auth/username", WithValidation[dto.UpdateUsernameRequest](config.ValidationMiddleware, h.UpdateUsername))
protected.Put("/auth/password", WithValidation[dto.UpdatePasswordRequest](config.ValidationMiddleware, h.UpdatePassword))
protected.Delete("/auth/account", h.DeleteAccount)
}

View File

@@ -252,8 +252,8 @@ func TestAuthHandlerLoginSuccess(t *testing.T) {
}
handler := newAuthHandler(repo)
body := bytes.NewBufferString(`{"username":"user","password":"Password123!"}`)
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", body)
bodyStr := `{"username":"user","password":"Password123!"}`
request := createLoginRequest(bodyStr)
recorder := httptest.NewRecorder()
handler.Login(recorder, request)
@@ -274,17 +274,17 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler := newAuthHandler(&testutils.UserRepositoryStub{})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString("invalid"))
request := createLoginRequest("invalid")
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":" ","password":""}`))
request = createLoginRequest(`{"username":" ","password":""}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"WrongPass123!"}`))
request = createLoginRequest(`{"username":"user","password":"WrongPass123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
@@ -294,7 +294,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
}}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`))
request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden)
@@ -304,7 +304,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`))
request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
@@ -330,8 +330,7 @@ func TestAuthHandlerRegisterSuccess(t *testing.T) {
return nil
}})
body := bytes.NewBufferString(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`)
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", body)
request := createRegisterRequest(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.Register(recorder, request)
@@ -354,12 +353,12 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
handler := newAuthHandler(repo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString("invalid"))
request := createRegisterRequest("invalid")
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"","email":"","password":""}`))
request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -368,7 +367,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
}}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"new","email":"taken@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"new","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
@@ -382,7 +381,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"another","email":"taken@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"another","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
}
@@ -477,7 +476,7 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user@example.com"}`))
request := createForgotPasswordRequest(`{"username_or_email":"user@example.com"}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
@@ -495,19 +494,19 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}})
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user"}`))
request = createForgotPasswordRequest(`{"username_or_email":"user"}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":""}`))
request = createForgotPasswordRequest(`{"username_or_email":""}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`invalid json`))
request = createForgotPasswordRequest(`invalid json`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -518,25 +517,25 @@ func TestAuthHandlerResetPassword(t *testing.T) {
handler := newAuthHandler(repo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"new_password":"NewPassword123!"}`))
request := createResetPasswordRequest(`{"new_password":"NewPassword123!"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token"}`))
request = createResetPasswordRequest(`{"token":"valid_token"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token","new_password":"short"}`))
request = createResetPasswordRequest(`{"token":"valid_token","new_password":"short"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`invalid json`))
request = createResetPasswordRequest(`invalid json`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -602,7 +601,7 @@ func TestAuthHandlerResetPasswordServiceOutcomes(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"abc","new_password":"Password123!"}`))
request := createResetPasswordRequest(`{"token":"abc","new_password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.ResetPassword(recorder, request)
@@ -664,7 +663,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "empty email",
@@ -702,7 +701,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/email", bytes.NewBufferString(tt.requestBody))
request := createUpdateEmailRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -789,7 +788,7 @@ func TestAuthHandlerUpdateUsername(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/username", bytes.NewBufferString(tt.requestBody))
request := createUpdateUsernameRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -886,7 +885,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
tt.mockSetup(repo)
handler := newAuthHandler(repo)
request := httptest.NewRequest(http.MethodPut, "/api/auth/password", bytes.NewBufferString(tt.requestBody))
request := createUpdatePasswordRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -984,8 +983,7 @@ func TestAuthHandlerDeleteAccount(t *testing.T) {
func TestAuthHandlerResendVerificationEmail(t *testing.T) {
makeRequest := func(body string, setup func(*mockAuthService)) (*httptest.ResponseRecorder, AuthResponse) {
request := httptest.NewRequest(http.MethodPost, "/api/auth/resend-verification", bytes.NewBufferString(body))
request = request.WithContext(context.Background())
request := createResendVerificationRequest(body)
repo := &testutils.UserRepositoryStub{}
mockService := &mockAuthService{}
@@ -1014,7 +1012,7 @@ func TestAuthHandlerResendVerificationEmail(t *testing.T) {
name: "invalid json",
body: "not-json",
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "missing email",
@@ -1139,7 +1137,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
name: "invalid json",
body: "not-json",
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "missing token",
@@ -1209,7 +1207,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/account/confirm", bytes.NewBufferString(tt.body))
request := createConfirmAccountDeletionRequest(tt.body)
recorder := httptest.NewRecorder()
handler.ConfirmAccountDeletion(recorder, request)
@@ -1338,9 +1336,7 @@ func TestAuthHandler_ConcurrentAccess(t *testing.T) {
for i := 0; i < concurrency; i++ {
go func() {
body := bytes.NewBufferString(`{"username":"testuser","password":"Password123!"}`)
req := httptest.NewRequest("POST", "/api/auth/login", body)
req.Header.Set("Content-Type", "application/json")
req := createLoginRequest(`{"username":"testuser","password":"Password123!"}`)
w := httptest.NewRecorder()
handler.Login(w, req)
@@ -1370,8 +1366,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}, nil
}
body := bytes.NewBufferString(`{"refresh_token":"valid_refresh_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"valid_refresh_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1381,8 +1376,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
})
t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`invalid json`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1392,8 +1386,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
})
t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":""}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1407,8 +1400,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenExpired
}
body := bytes.NewBufferString(`{"refresh_token":"expired_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"expired_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1422,8 +1414,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenInvalid
}
body := bytes.NewBufferString(`{"refresh_token":"invalid_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"invalid_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1437,8 +1428,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrAccountLocked
}
body := bytes.NewBufferString(`{"refresh_token":"locked_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"locked_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1452,8 +1442,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, fmt.Errorf("internal error")
}
body := bytes.NewBufferString(`{"refresh_token":"error_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"error_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1473,8 +1462,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return nil
}
body := bytes.NewBufferString(`{"refresh_token":"token_to_revoke"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":"token_to_revoke"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1484,8 +1472,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
})
t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`invalid json`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1495,8 +1482,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
})
t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":""}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1510,8 +1496,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return fmt.Errorf("revoke failed")
}
body := bytes.NewBufferString(`{"refresh_token":"token"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":"token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/url"
"reflect"
"strconv"
"strings"
"time"
@@ -290,3 +291,24 @@ func HandleServiceError(w http.ResponseWriter, err error, defaultMsg string, def
SendErrorResponse(w, defaultMsg, defaultCode)
return false
}
func GetValidatedDTO[T any](r *http.Request) (*T, bool) {
dtoVal := middleware.GetValidatedDTOFromContext(r.Context())
if dtoVal == nil {
return nil, false
}
dto, ok := dtoVal.(*T)
return dto, ok
}
func WithValidation[T any](validationMiddleware func(http.Handler) http.Handler, handler http.HandlerFunc) http.HandlerFunc {
if validationMiddleware == nil {
return handler
}
var zero T
dtoType := reflect.TypeOf(zero)
return func(w http.ResponseWriter, r *http.Request) {
ctx := middleware.SetDTOTypeInContext(r.Context(), dtoType)
validationMiddleware(handler).ServeHTTP(w, r.WithContext(ctx))
}
}

View File

@@ -1,6 +1,7 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"errors"
@@ -11,6 +12,7 @@ import (
"testing"
"goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/middleware"
"goyco/internal/services"
"goyco/internal/testutils"
@@ -721,6 +723,74 @@ func TestDecodeJSONRequest(t *testing.T) {
}
}
func createRequestWithDTO[T any](method, url string, body []byte) *http.Request {
r := httptest.NewRequest(method, url, bytes.NewReader(body))
var dto T
if len(body) > 0 {
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&dto); err != nil {
return r
}
}
ctx := middleware.SetValidatedDTOInContext(r.Context(), &dto)
return r.WithContext(ctx)
}
func createLoginRequest(body string) *http.Request {
return createRequestWithDTO[dto.LoginRequest](http.MethodPost, "/api/auth/login", []byte(body))
}
func createRegisterRequest(body string) *http.Request {
return createRequestWithDTO[dto.RegisterRequest](http.MethodPost, "/api/auth/register", []byte(body))
}
func createResendVerificationRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResendVerificationRequest](http.MethodPost, "/api/auth/resend-verification", []byte(body))
}
func createForgotPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ForgotPasswordRequest](http.MethodPost, "/api/auth/forgot-password", []byte(body))
}
func createResetPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResetPasswordRequest](http.MethodPost, "/api/auth/reset-password", []byte(body))
}
func createUpdateEmailRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateEmailRequest](http.MethodPut, "/api/auth/email", []byte(body))
}
func createUpdateUsernameRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateUsernameRequest](http.MethodPut, "/api/auth/username", []byte(body))
}
func createUpdatePasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePasswordRequest](http.MethodPut, "/api/auth/password", []byte(body))
}
func createConfirmAccountDeletionRequest(body string) *http.Request {
return createRequestWithDTO[dto.ConfirmAccountDeletionRequest](http.MethodPost, "/api/auth/account/confirm", []byte(body))
}
func createRefreshTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RefreshTokenRequest](http.MethodPost, "/api/auth/refresh", []byte(body))
}
func createRevokeTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RevokeTokenRequest](http.MethodPost, "/api/auth/revoke", []byte(body))
}
func createCreatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.CreatePostRequest](http.MethodPost, "/api/posts", []byte(body))
}
func createUpdatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePostRequest](http.MethodPut, "/api/posts/1", []byte(body))
}
func createVoteRequest(body string) *http.Request {
return createRequestWithDTO[dto.CastVoteRequest](http.MethodPost, "/api/posts/1/vote", []byte(body))
}
func TestParsePagination(t *testing.T) {
tests := []struct {
name string

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"html/template"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
@@ -877,7 +878,8 @@ func (h *PageHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -897,7 +899,8 @@ func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -960,13 +963,15 @@ func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
}
h.clearAuthCookie(w, r)
http.Redirect(w, r, "/login?flash=Email updated. Check your inbox to confirm the new address. You will need to sign in again after verification.", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Email updated. Check your inbox to confirm the new address. You will need to sign in again after verification.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -1022,13 +1027,15 @@ func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, "/settings?flash=Username updated successfully.", http.StatusSeeOther)
redirectURL := "/settings?flash=" + url.QueryEscape("Username updated successfully.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -1140,13 +1147,15 @@ func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, "/settings?flash=Password updated successfully.", http.StatusSeeOther)
redirectURL := "/settings?flash=" + url.QueryEscape("Password updated successfully.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -1204,7 +1213,8 @@ func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, "/settings?flash=Check your inbox for a confirmation link to finish deleting your account.", http.StatusSeeOther)
redirectURL := "/settings?flash=" + url.QueryEscape("Check your inbox for a confirmation link to finish deleting your account.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
@@ -1328,7 +1338,8 @@ func (h *PageHandler) clearAuthCookie(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) Vote(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Please sign in to vote", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Please sign in to vote")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}

View File

@@ -36,11 +36,6 @@ func NewPostHandler(postRepo repositories.PostRepository, titleFetcher services.
type PostResponse = CommonResponse
type UpdatePostRequest struct {
Title string `json:"title"`
Content string `json:"content"`
}
// @Summary Get posts
// @Description Get a list of posts with pagination. Posts include vote statistics (up_votes, down_votes, score) and current user's vote status.
// @Tags posts
@@ -111,7 +106,7 @@ func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body CreatePostRequest true "Post data"
// @Param request body dto.CreatePostRequest true "Post data"
// @Success 201 {object} PostResponse
// @Failure 400 {object} PostResponse "Invalid request data or validation failed"
// @Failure 401 {object} PostResponse "Authentication required"
@@ -120,32 +115,9 @@ func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} PostResponse "Internal server error"
// @Router /api/posts [post]
func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
var req struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
req.Title = security.SanitizeInput(req.Title)
req.URL = security.SanitizeURL(req.URL)
req.Content = security.SanitizePostContent(req.Content)
if req.URL == "" {
SendErrorResponse(w, "URL is required", http.StatusBadRequest)
return
}
if len(req.Title) > 200 {
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
return
}
if len(req.Content) > 10000 {
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
req, ok := GetValidatedDTO[dto.CreatePostRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -154,13 +126,20 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return
}
title := req.Title
title := security.SanitizeInput(req.Title)
url := security.SanitizeURL(req.URL)
content := security.SanitizePostContent(req.Content)
if url == "" {
SendErrorResponse(w, "Invalid URL", http.StatusBadRequest)
return
}
if title == "" && h.titleFetcher != nil {
titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second)
defer cancel()
fetchedTitle, err := h.titleFetcher.FetchTitle(titleCtx, req.URL)
fetchedTitle, err := h.titleFetcher.FetchTitle(titleCtx, url)
if err != nil {
switch {
case errors.Is(err, services.ErrUnsupportedScheme):
@@ -186,10 +165,20 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return
}
if len(title) > 200 {
SendErrorResponse(w, "Title must be at most 200 characters", http.StatusBadRequest)
return
}
if len(content) > 10000 {
SendErrorResponse(w, "Content must be at most 10000 characters", http.StatusBadRequest)
return
}
post := &database.Post{
Title: title,
URL: req.URL,
Content: req.Content,
URL: url,
Content: content,
AuthorID: &userID,
}
@@ -257,7 +246,7 @@ func (h *PostHandler) SearchPosts(w http.ResponseWriter, r *http.Request) {
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Param request body UpdatePostRequest true "Post update data"
// @Param request body dto.UpdatePostRequest true "Post update data"
// @Success 200 {object} PostResponse "Post updated successfully"
// @Failure 400 {object} PostResponse "Invalid request data or validation failed"
// @Failure 401 {object} PostResponse "Authentication required"
@@ -286,40 +275,27 @@ func (h *PostHandler) UpdatePost(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
Title string `json:"title"`
Content string `json:"content"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdatePostRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
req.Title = security.SanitizeInput(req.Title)
req.Content = security.SanitizePostContent(req.Content)
title := security.SanitizeInput(req.Title)
content := security.SanitizePostContent(req.Content)
if len(req.Title) > 200 {
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
return
}
if len(req.Content) > 10000 {
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
return
}
if err := validation.ValidateTitle(req.Title); err != nil {
if err := validation.ValidateTitle(title); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidateContent(req.Content); err != nil {
if err := validation.ValidateContent(content); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
post.Title = req.Title
post.Content = req.Content
post.Title = title
post.Content = content
if err := h.postRepo.Update(post); err != nil {
SendErrorResponse(w, "Failed to update post", http.StatusInternalServerError)
@@ -458,7 +434,7 @@ func (h *PostHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
if config.GeneralRateLimit != nil {
protected = config.GeneralRateLimit(protected)
}
protected.Post("/posts", h.CreatePost)
protected.Put("/posts/{id}", h.UpdatePost)
protected.Post("/posts", WithValidation[dto.CreatePostRequest](config.ValidationMiddleware, h.CreatePost))
protected.Put("/posts/{id}", WithValidation[dto.UpdatePostRequest](config.ValidationMiddleware, h.UpdatePost))
protected.Delete("/posts/{id}", h.DeletePost)
}

View File

@@ -69,9 +69,8 @@ func TestPostHandlerCreatePostWithTitleFetcher(t *testing.T) {
handler := NewPostHandler(repo, titleFetcher, nil)
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"url":"https://example.com","content":"Test content"}`))
request := createCreatePostRequest(`{"url":"https://example.com","content":"Test content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
@@ -171,7 +170,7 @@ func TestPostHandlerUpdatePostUnauthorized(t *testing.T) {
handler := NewPostHandler(repo, nil, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/1", bytes.NewBufferString(`{"title":"Updated Title","content":"Updated content"}`))
request := createUpdatePostRequest(`{"title":"Updated Title","content":"Updated content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
request.Header.Set("Content-Type", "application/json")
@@ -278,8 +277,7 @@ func TestPostHandlerCreatePostSuccess(t *testing.T) {
handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com","content":"Go"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request := createCreatePostRequest(`{"title":" ","url":"https://example.com","content":"Go"}`)
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42))
request = request.WithContext(ctx)
@@ -297,7 +295,7 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"","url":"","content":""}`))
request := createCreatePostRequest(`{"title":"","url":"","content":""}`)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
@@ -305,14 +303,14 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`invalid json`))
request = createCreatePostRequest(`invalid json`)
handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"ok","url":"https://example.com"}`))
request = createCreatePostRequest(`{"title":"ok","url":"https://example.com"}`)
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
}
@@ -336,8 +334,7 @@ func TestPostHandlerCreatePostTitleFetcherErrors(t *testing.T) {
return "", tc.err
}}
handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request := createCreatePostRequest(`{"title":" ","url":"https://example.com"}`)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
@@ -495,7 +492,7 @@ func TestPostHandlerUpdatePost(t *testing.T) {
}
handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/"+tt.postID, bytes.NewBufferString(tt.requestBody))
request := createUpdatePostRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -663,12 +660,15 @@ func (e *errorVoteRepository) GetByVoteHash(string) (*database.Vote, error) {
func (e *errorVoteRepository) GetByPostID(uint) ([]database.Vote, error) {
return nil, errors.New("database error")
}
func (e *errorVoteRepository) GetByUserID(uint) ([]database.Vote, error) { return nil, nil }
func (e *errorVoteRepository) Update(*database.Vote) error { return nil }
func (e *errorVoteRepository) Delete(uint) error { return nil }
func (e *errorVoteRepository) Count() (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByPostID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByUserID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) GetByUserID(uint) ([]database.Vote, error) { return nil, nil }
func (e *errorVoteRepository) Update(*database.Vote) error { return nil }
func (e *errorVoteRepository) Delete(uint) error { return nil }
func (e *errorVoteRepository) Count() (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByPostID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByUserID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) GetVoteCountsByPostID(uint) (int, int, error) {
return 0, 0, errors.New("database error")
}
func (e *errorVoteRepository) WithTx(*gorm.DB) repositories.VoteRepository { return e }
func TestPostHandler_EdgeCases(t *testing.T) {

View File

@@ -13,9 +13,10 @@ type RouteModule interface {
}
type RouteModuleConfig struct {
AuthService middleware.TokenVerifier
GeneralRateLimit func(chi.Router) chi.Router
AuthRateLimit func(chi.Router) chi.Router
CSRFMiddleware func(http.Handler) http.Handler
AuthMiddleware func(http.Handler) http.Handler
AuthService middleware.TokenVerifier
GeneralRateLimit func(chi.Router) chi.Router
AuthRateLimit func(chi.Router) chi.Router
CSRFMiddleware func(http.Handler) http.Handler
AuthMiddleware func(http.Handler) http.Handler
ValidationMiddleware func(http.Handler) http.Handler
}

View File

@@ -24,10 +24,6 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
t.Run("XSS_"+payload[:minLen(20, len(payload))], func(t *testing.T) {
repo := &testutils.PostRepositoryStub{
CreateFn: func(post *database.Post) error {
sanitizedTitle := security.SanitizeInput(payload)
if post.Title != sanitizedTitle {
t.Errorf("Expected sanitized title, got %q", post.Title)
}
return nil
},
}
@@ -41,14 +37,46 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
}
body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
if recorder.Code != http.StatusCreated {
t.Errorf("Expected status %d, got %d. Body: %s", http.StatusCreated, recorder.Code, recorder.Body.String())
return
}
var response CommonResponse
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if !response.Success {
t.Errorf("Expected successful response, got error: %s", response.Error)
return
}
dataMap, ok := response.Data.(map[string]any)
if !ok {
t.Fatalf("Expected data to be a map, got %T", response.Data)
}
title, ok := dataMap["title"].(string)
if !ok {
t.Fatalf("Expected title to be a string, got %T", dataMap["title"])
}
expectedSanitized := security.SanitizeInput(payload)
if title != expectedSanitized {
t.Errorf("Expected sanitized title %q, got %q", expectedSanitized, title)
}
if title == payload {
t.Errorf("Title was not sanitized - original payload %q matches response %q", payload, title)
}
})
}
}
@@ -123,7 +151,7 @@ func TestPostHandler_InputValidation(t *testing.T) {
}
body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
@@ -230,7 +258,7 @@ func TestAuthHandler_PasswordValidation(t *testing.T) {
}
body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
@@ -290,7 +318,7 @@ func TestAuthHandler_UsernameSanitization(t *testing.T) {
}
body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()

View File

@@ -91,7 +91,7 @@ func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body RegisterRequest true "User data"
// @Param request body dto.RegisterRequest true "User data"
// @Success 201 {object} UserResponse "User created successfully"
// @Failure 400 {object} UserResponse "Invalid request data or validation failed"
// @Failure 401 {object} UserResponse "Authentication required"
@@ -99,13 +99,9 @@ func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} UserResponse "Internal server error"
// @Router /api/users [post]
func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RegisterRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -189,7 +185,7 @@ func (h *UserHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
}
protected.Get("/users", h.GetUsers)
protected.Post("/users", h.CreateUser)
protected.Post("/users", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.CreateUser))
protected.Get("/users/{id}", h.GetUser)
protected.Get("/users/{id}/posts", h.GetUserPosts)
}

View File

@@ -1,7 +1,6 @@
package handlers
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -103,7 +102,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
return nil
}})
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
request := createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
@@ -126,14 +125,14 @@ func TestUserHandlerCreateUser(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString("invalid"))
request = createRegisterRequest("invalid")
handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"","email":"","password":""}`))
request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode)
@@ -144,7 +143,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
}
handler = newUserHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
}
@@ -350,7 +349,7 @@ func TestUserHandler_PasswordValidation(t *testing.T) {
handler := NewUserHandler(repo, authService)
requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password)
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(requestBody))
request := createRegisterRequest(requestBody)
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/services"
"github.com/go-chi/chi/v5"
@@ -39,11 +40,6 @@ func NewVoteHandler(voteService *services.VoteService) *VoteHandler {
}
}
// @Description Vote request with type field. All votes are handled the same way.
type VoteRequest struct {
Type string `json:"type" example:"up" enums:"up,down,none" description:"Vote type: 'up' for upvote, 'down' for downvote, 'none' to remove vote"`
}
type VoteResponse = CommonResponse
// @Summary Cast a vote on a post
@@ -62,7 +58,7 @@ type VoteResponse = CommonResponse
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Param request body VoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)"
// @Param request body dto.CastVoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)"
// @Success 200 {object} VoteResponse "Vote cast successfully with updated post statistics"
// @Failure 401 {object} VoteResponse "Authentication required"
// @Failure 400 {object} VoteResponse "Invalid request data or vote type"
@@ -82,8 +78,9 @@ func (h *VoteHandler) CastVote(w http.ResponseWriter, r *http.Request) {
return
}
var req VoteRequest
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.CastVoteRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -286,7 +283,7 @@ func (h *VoteHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
protected = config.GeneralRateLimit(protected)
}
protected.Post("/posts/{id}/vote", h.CastVote)
protected.Post("/posts/{id}/vote", WithValidation[dto.CastVoteRequest](config.ValidationMiddleware, h.CastVote))
protected.Delete("/posts/{id}/vote", h.RemoveVote)
protected.Get("/posts/{id}/vote", h.GetUserVote)
protected.Get("/posts/{id}/votes", h.GetPostVotes)

View File

@@ -1,7 +1,6 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -59,13 +58,13 @@ func TestVoteHandlerCastVote(t *testing.T) {
handler := newVoteHandlerWithRepos()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/abc/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -73,7 +72,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`invalid`))
request = createVoteRequest(`invalid`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -83,7 +82,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"maybe"}`))
request = createVoteRequest(`{"type":"maybe"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -93,7 +92,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -101,7 +100,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -111,7 +110,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx)
@@ -125,7 +124,7 @@ func TestVoteHandlerCastVotePostNotFound(t *testing.T) {
handler, _, posts := newVoteHandlerWithReposRefs()
delete(posts, 1)
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -164,7 +163,7 @@ func TestVoteHandlerRemoveVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -202,7 +201,7 @@ func TestVoteHandlerRemoveVotePostNotFound(t *testing.T) {
func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) {
handler, voteRepo, _ := newVoteHandlerWithReposRefs()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -257,7 +256,7 @@ func TestVoteHandlerGetUserVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -301,7 +300,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -311,7 +310,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -345,7 +344,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1"
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -363,7 +362,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -373,7 +372,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -404,7 +403,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1"
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -414,7 +413,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -424,7 +423,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx)
@@ -452,7 +451,7 @@ func TestVoteFlowRegression(t *testing.T) {
t.Run("ErrorHandlingEdgeCases", func(t *testing.T) {
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(``))
request := createVoteRequest(``)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -460,7 +459,7 @@ func TestVoteFlowRegression(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{}`))
request = createVoteRequest(`{}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -470,7 +469,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"invalid"}`))
request = createVoteRequest(`{"type":"invalid"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)

View File

@@ -53,25 +53,25 @@ func TestIntegration_Caching(t *testing.T) {
router := ctx.Router
t.Run("Cache_Hit_On_Repeated_Requests", func(t *testing.T) {
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond)
req2 := httptest.NewRequest("GET", "/api/posts", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("GET", "/api/posts", nil)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
if rec1.Code != rec2.Code {
if firstRecorder.Code != secondRecorder.Code {
t.Error("Cached responses should have same status code")
}
if rec1.Body.String() != rec2.Body.String() {
if firstRecorder.Body.String() != secondRecorder.Body.String() {
t.Error("Cached responses should have same body")
}
if rec2.Header().Get("X-Cache") != "HIT" {
if secondRecorder.Header().Get("X-Cache") != "HIT" {
t.Log("Cache may not be enabled for this path or response may not be cacheable")
}
})
@@ -80,9 +80,9 @@ func TestIntegration_Caching(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "cache_post_user", "cache_post@example.com")
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond)
@@ -92,12 +92,12 @@ func TestIntegration_Caching(t *testing.T) {
"content": "Test content",
}
body, _ := json.Marshal(postBody)
req2 := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
secondRequest.Header.Set("Content-Type", "application/json")
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
time.Sleep(10 * time.Millisecond)
@@ -105,17 +105,17 @@ func TestIntegration_Caching(t *testing.T) {
rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3)
if rec1.Body.String() == rec3.Body.String() && rec1.Code == http.StatusOK && rec3.Code == http.StatusOK {
if firstRecorder.Body.String() == rec3.Body.String() && firstRecorder.Code == http.StatusOK && rec3.Code == http.StatusOK {
t.Log("Cache invalidation may not be working or cache may not be enabled")
}
})
t.Run("Cache_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if rec.Header().Get("Cache-Control") == "" && rec.Header().Get("X-Cache") == "" {
if recorder.Header().Get("Cache-Control") == "" && recorder.Header().Get("X-Cache") == "" {
t.Log("Cache headers may not be present for all responses")
}
})
@@ -126,18 +126,18 @@ func TestIntegration_Caching(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Cache Delete Post", "https://example.com/cache-delete")
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond)
req2 := httptest.NewRequest("DELETE", "/api/posts/"+fmt.Sprintf("%d", post.ID), nil)
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
req2 = testutils.WithURLParams(req2, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("DELETE", "/api/posts/"+fmt.Sprintf("%d", post.ID), nil)
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRequest = testutils.WithURLParams(secondRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
time.Sleep(10 * time.Millisecond)
@@ -145,7 +145,7 @@ func TestIntegration_Caching(t *testing.T) {
rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3)
if rec1.Body.String() == rec3.Body.String() && rec1.Code == http.StatusOK && rec3.Code == http.StatusOK {
if firstRecorder.Body.String() == rec3.Body.String() && firstRecorder.Code == http.StatusOK && rec3.Code == http.StatusOK {
t.Log("Cache invalidation may not be working or cache may not be enabled")
}
})

View File

@@ -1,15 +1,14 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"goyco/internal/middleware"
"goyco/internal/services"
"goyco/internal/testutils"
)
@@ -20,17 +19,8 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "logout_user", "logout@example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/logout", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makePostRequest(t, ctx.Router, "/api/auth/logout", map[string]any{}, user, nil)
assertStatus(t, request, http.StatusOK)
})
t.Run("Auth_Revoke_Token_Endpoint", func(t *testing.T) {
@@ -42,52 +32,23 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Fatalf("Failed to login: %v", err)
}
reqBody := map[string]string{
"refresh_token": loginResult.RefreshToken,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/revoke", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makePostRequest(t, ctx.Router, "/api/auth/revoke", map[string]any{"refresh_token": loginResult.RefreshToken}, user, nil)
assertStatus(t, request, http.StatusOK)
})
t.Run("Auth_Revoke_All_Tokens_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "revoke_all_user", "revoke_all@example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/revoke-all", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makePostRequest(t, ctx.Router, "/api/auth/revoke-all", map[string]any{}, user, nil)
assertStatus(t, request, http.StatusOK)
})
t.Run("Auth_Resend_Verification_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
"email": "resend@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/resend-verification", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusOK, http.StatusNotFound)
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/resend-verification", map[string]any{"email": "resend@example.com"})
assertStatusRange(t, request, http.StatusOK, http.StatusNotFound)
})
t.Run("Auth_Confirm_Email_Endpoint", func(t *testing.T) {
@@ -99,36 +60,20 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
token = "test-token"
}
req := httptest.NewRequest("GET", "/api/auth/confirm?token="+url.QueryEscape(token), nil)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusOK, http.StatusBadRequest)
request := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(token))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Auth_Update_Email_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_email_api_user", "update_email_api@example.com")
reqBody := map[string]string{
"email": "newemail@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "newemail@example.com"}, user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if email, ok := data["email"].(string); ok && email != "newemail@example.com" {
t.Errorf("Expected email to be updated, got %s", email)
}
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if email, ok := data["email"].(string); ok && email != "newemail@example.com" {
t.Errorf("Expected email to be updated, got %s", email)
}
}
})
@@ -137,24 +82,12 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_username_api_user", "update_username_api@example.com")
reqBody := map[string]string{
"username": "new_username",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/username", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePutRequest(t, ctx.Router, "/api/auth/username", map[string]any{"username": "new_username"}, user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if username, ok := data["username"].(string); ok && username != "new_username" {
t.Errorf("Expected username to be updated, got %s", username)
}
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if username, ok := data["username"].(string); ok && username != "new_username" {
t.Errorf("Expected username to be updated, got %s", username)
}
}
})
@@ -163,19 +96,12 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_list_user", "users_list@example.com")
req := httptest.NewRequest("GET", "/api/users", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["users"]; !exists {
t.Error("Expected users in response")
}
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["users"]; !exists {
t.Error("Expected users in response")
}
}
})
@@ -184,21 +110,13 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_get_user", "users_get@example.com")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != user.User.ID {
t.Errorf("Expected user ID %d, got %.0f", user.User.ID, id)
}
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != user.User.ID {
t.Errorf("Expected user ID %d, got %.0f", user.User.ID, id)
}
}
}
@@ -210,24 +128,16 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "User Posts Test", "https://example.com/user-posts")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) == 0 {
t.Error("Expected at least one post in response")
}
} else {
t.Error("Expected posts array in response")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) == 0 {
t.Error("Expected at least one post in response")
}
} else {
t.Error("Expected posts array in response")
}
}
})
@@ -236,26 +146,16 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_create_admin", "users_create_admin@example.com")
reqBody := map[string]string{
request := makePostRequest(t, ctx.Router, "/api/users", map[string]any{
"username": "created_user",
"email": "created@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
}, user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusCreated)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["user"]; !exists {
t.Error("Expected user in response")
}
response := assertJSONResponse(t, request, http.StatusCreated)
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["user"]; !exists {
t.Error("Expected user in response")
}
}
})
@@ -266,27 +166,16 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Update Test Post", "https://example.com/update-test")
reqBody := map[string]string{
request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if postData, ok := data["post"].(map[string]any); ok {
if title, ok := postData["title"].(string); ok && title != "Updated Title" {
t.Errorf("Expected title 'Updated Title', got '%s'", title)
}
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if postData, ok := data["post"].(map[string]any); ok {
if title, ok := postData["title"].(string); ok && title != "Updated Title" {
t.Errorf("Expected title 'Updated Title', got '%s'", title)
}
}
}
@@ -298,20 +187,11 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Test Post", "https://example.com/delete-test")
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, request, http.StatusOK)
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
assertStatus(t, getRec, http.StatusNotFound)
getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID))
assertStatus(t, getRequest, http.StatusNotFound)
})
t.Run("Votes_Get_All_Endpoint", func(t *testing.T) {
@@ -319,35 +199,17 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "votes_get_all_user", "votes_get_all@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Votes Test Post", "https://example.com/votes-test")
makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteBody := map[string]string{"type": "up"}
voteBodyBytes, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
req := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if votes, ok := data["votes"].([]any); ok {
if len(votes) == 0 {
t.Error("Expected at least one vote in response")
}
} else {
t.Error("Expected votes array in response")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if votes, ok := data["votes"].([]any); ok {
if len(votes) == 0 {
t.Error("Expected at least one vote in response")
}
} else {
t.Error("Expected votes array in response")
}
}
})
@@ -358,49 +220,461 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Remove Test", "https://example.com/vote-remove")
voteBody := map[string]string{"type": "up"}
voteBodyBytes, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, request, http.StatusOK)
})
t.Run("API_Info_Endpoint", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, ctx.Router, "/api")
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["endpoints"]; !exists {
t.Error("Expected endpoints in API info")
}
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["endpoints"]; !exists {
t.Error("Expected endpoints in API info")
}
}
})
t.Run("Swagger_Documentation_Endpoint", func(t *testing.T) {
req := httptest.NewRequest("GET", "/swagger/index.html", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, ctx.Router, "/swagger/index.html")
assertStatusRange(t, request, http.StatusOK, http.StatusNotFound)
})
ctx.Router.ServeHTTP(rec, req)
t.Run("Search_Endpoint_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "search_edge"), uniqueTestEmail(t, "search_edge"))
assertStatusRange(t, rec, http.StatusOK, http.StatusNotFound)
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Searchable Post One", "https://example.com/one")
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Searchable Post Two", "https://example.com/two")
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Different Content", "https://example.com/three")
t.Run("Empty_Search_Results", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=nonexistentterm12345")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) != 0 {
t.Errorf("Expected empty search results, got %d posts", len(posts))
}
}
if count, ok := data["count"].(float64); ok && count != 0 {
t.Errorf("Expected count 0, got %.0f", count)
}
}
})
t.Run("Search_With_Pagination", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=Searchable&limit=1&offset=0")
response := assertJSONResponse(t, request, http.StatusOK)
var firstPostID any
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 1 {
t.Errorf("Expected at most 1 post with limit=1, got %d", len(posts))
}
if len(posts) > 0 {
if post, ok := posts[0].(map[string]any); ok {
firstPostID = post["id"]
}
}
}
if limit, ok := data["limit"].(float64); ok && limit != 1 {
t.Errorf("Expected limit 1 in response, got %.0f", limit)
}
if offset, ok := data["offset"].(float64); ok && offset != 0 {
t.Errorf("Expected offset 0 in response, got %.0f", offset)
}
}
secondRequest := makeGetRequest(t, ctx.Router, "/api/posts/search?q=Searchable&limit=1&offset=1")
secondResponse := assertJSONResponse(t, secondRequest, http.StatusOK)
if data, ok := getDataFromResponse(secondResponse); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 1 {
t.Errorf("Expected at most 1 post with limit=1 and offset=1, got %d", len(posts))
}
if len(posts) > 0 && firstPostID != nil {
if post, ok := posts[0].(map[string]any); ok {
if post["id"] == firstPostID {
t.Error("Expected different post with offset=1, got same post as offset=0")
}
}
}
}
}
})
t.Run("Search_With_Special_Characters", func(t *testing.T) {
specialQueries := []string{
"Searchable%20Post",
"Searchable'Post",
"Searchable\"Post",
"Searchable;Post",
"Searchable--Post",
}
for _, query := range specialQueries {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q="+url.QueryEscape(query))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
}
})
t.Run("Search_Empty_Query", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) != 0 {
t.Errorf("Expected empty results for empty query, got %d posts", len(posts))
}
}
if count, ok := data["count"].(float64); ok && count != 0 {
t.Errorf("Expected count 0 for empty query, got %.0f", count)
}
}
})
t.Run("Search_With_Very_Long_Query", func(t *testing.T) {
longQuery := strings.Repeat("a", 1000)
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q="+url.QueryEscape(longQuery))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Search_Case_Insensitive", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=SEARCHABLE")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) == 0 {
t.Error("Expected case-insensitive search to find posts")
}
}
}
})
})
t.Run("Title_Fetch_Endpoint_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
t.Run("Missing_URL_Parameter", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/title")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Empty_URL_Parameter", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Invalid_URL_Format", func(t *testing.T) {
invalidURLs := []string{
"not-a-url",
"://invalid",
"http://",
"https://",
}
for _, invalidURL := range invalidURLs {
ctx.Suite.TitleFetcher.SetError(services.ErrUnsupportedScheme)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(invalidURL))
assertErrorResponse(t, request, http.StatusBadRequest)
}
})
t.Run("Unsupported_URL_Schemes", func(t *testing.T) {
unsupportedSchemes := []string{
"ftp://example.com",
"file:///etc/passwd",
"javascript:alert(1)",
"data:text/html,<script>alert(1)</script>",
}
for _, schemeURL := range unsupportedSchemes {
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(schemeURL))
assertErrorResponse(t, request, http.StatusBadRequest)
}
})
t.Run("SSRF_Protection_Localhost", func(t *testing.T) {
ssrfURLs := []string{
"http://localhost",
"http://127.0.0.1",
"http://127.0.0.1:8080",
"http://[::1]",
"http://0.0.0.0",
}
for _, ssrfURL := range ssrfURLs {
ctx.Suite.TitleFetcher.SetError(services.ErrSSRFBlocked)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(ssrfURL))
assertStatusRange(t, request, http.StatusBadRequest, http.StatusBadGateway)
}
})
t.Run("SSRF_Protection_Private_IPs", func(t *testing.T) {
privateIPs := []string{
"http://192.168.1.1",
"http://10.0.0.1",
"http://172.16.0.1",
}
for _, privateIP := range privateIPs {
ctx.Suite.TitleFetcher.SetError(services.ErrSSRFBlocked)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(privateIP))
assertStatusRange(t, request, http.StatusBadRequest, http.StatusBadGateway)
}
})
t.Run("Title_Fetch_Error_Handling", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetError(services.ErrTitleNotFound)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=https://example.com/notitle")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Valid_URL_Success", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Valid Title")
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=https://example.com/valid")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if title, ok := data["title"].(string); ok {
if title != "Valid Title" {
t.Errorf("Expected title 'Valid Title', got '%s'", title)
}
} else {
t.Error("Expected title in response data")
}
}
})
})
t.Run("Get_User_Vote_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "vote_edge"), uniqueTestEmail(t, "vote_edge"))
secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "vote_edge2"), uniqueTestEmail(t, "vote_edge2"))
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Edge Test Post", "https://example.com/vote-edge")
t.Run("Get_Vote_When_User_Has_Voted", func(t *testing.T) {
voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRequest, http.StatusOK)
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if hasVote, ok := data["has_vote"].(bool); !ok || !hasVote {
t.Error("Expected has_vote to be true when user has voted")
}
if vote, ok := data["vote"]; !ok || vote == nil {
t.Error("Expected vote object when user has voted")
}
}
})
t.Run("Get_Vote_When_User_Has_Not_Voted", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if hasVote, ok := data["has_vote"].(bool); ok {
if hasVote {
t.Error("Expected has_vote to be false when user has not voted")
}
} else {
t.Error("Expected has_vote field in response")
}
}
})
t.Run("Get_Vote_Invalid_Post_ID", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/posts/999999/vote", user, map[string]string{"id": "999999"})
if request.Code != http.StatusOK && request.Code != http.StatusNotFound {
t.Errorf("Expected status 200 or 404 for invalid post ID, got %d", request.Code)
}
})
t.Run("Get_Vote_Unauthenticated", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID))
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Get_Vote_Response_Structure", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success field to be true")
}
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["has_vote"]; !exists {
t.Error("Expected has_vote field in response data")
}
if _, exists := data["is_anonymous"]; !exists {
t.Error("Expected is_anonymous field in response data")
}
} else {
t.Error("Expected data field in response")
}
})
})
t.Run("Refresh_Token_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
refreshUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "refresh_edge"), uniqueTestEmail(t, "refresh_edge"))
t.Run("Refresh_With_Expired_Token", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
refreshToken, err := ctx.Suite.RefreshTokenRepo.GetByTokenHash(testutils.HashVerificationToken(loginResult.RefreshToken))
if err != nil {
t.Fatalf("Failed to get refresh token: %v", err)
}
refreshToken.ExpiresAt = time.Now().Add(-1 * time.Hour)
if err := ctx.Suite.DB.Model(refreshToken).Update("expires_at", refreshToken.ExpiresAt).Error; err != nil {
t.Fatalf("Failed to expire refresh token: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_With_Revoked_Token", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil {
t.Fatalf("Failed to revoke refresh token: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_With_Empty_Token", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": ""})
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Refresh_With_Missing_Token_Field", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{})
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Refresh_Token_Rotation", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
originalRefreshToken := loginResult.RefreshToken
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": originalRefreshToken})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if newAccessToken, ok := data["access_token"].(string); ok {
if newAccessToken == "" {
t.Error("Expected new access token in refresh response")
}
if newRefreshToken, ok := data["refresh_token"].(string); ok {
if newRefreshToken != "" && newRefreshToken == originalRefreshToken {
t.Log("Refresh token rotation may not be implemented (same token returned)")
}
}
}
}
})
t.Run("Refresh_After_Account_Lock", func(t *testing.T) {
lockedUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "refresh_lock"), uniqueTestEmail(t, "refresh_lock"))
loginResult, err := ctx.AuthService.Login(lockedUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
lockedUser.User.Locked = true
if err := ctx.Suite.UserRepo.Update(lockedUser.User); err != nil {
t.Fatalf("Failed to lock user: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertStatusRange(t, request, http.StatusUnauthorized, http.StatusForbidden)
})
t.Run("Refresh_With_Invalid_Token_Format", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": "invalid-token-format-12345"})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
})
t.Run("Pagination_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
paginationUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "pagination_edge"), uniqueTestEmail(t, "pagination_edge"))
for i := 0; i < 5; i++ {
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, paginationUser.User.ID, fmt.Sprintf("Pagination Post %d", i), fmt.Sprintf("https://example.com/pag%d", i))
}
t.Run("Negative_Limit", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?limit=-1")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Negative_Offset", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?offset=-1")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Very_Large_Limit", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?limit=10000")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Very_Large_Offset", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?offset=10000")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 0 {
t.Logf("Large offset returned %d posts (may be expected)", len(posts))
}
}
}
})
t.Run("Invalid_Pagination_Parameters", func(t *testing.T) {
invalidParams := []string{
"limit=abc",
"offset=xyz",
"limit=",
"offset=",
}
for _, param := range invalidParams {
request := makeGetRequest(t, ctx.Router, "/api/posts?"+param)
assertStatus(t, request, http.StatusOK)
}
})
})
}

View File

@@ -1,17 +1,12 @@
package integration
import (
"bytes"
"compress/gzip"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"goyco/internal/middleware"
"goyco/internal/testutils"
)
func TestIntegration_Compression(t *testing.T) {
@@ -19,16 +14,16 @@ func TestIntegration_Compression(t *testing.T) {
router := ctx.Router
t.Run("Response_Compression_Gzip", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip")
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
request.Header.Set("Accept-Encoding", "gzip")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
contentEncoding := rec.Header().Get("Content-Encoding")
contentEncoding := recorder.Header().Get("Content-Encoding")
if contentEncoding != "" && strings.Contains(contentEncoding, "gzip") {
assertHeaderContains(t, rec, "Content-Encoding", "gzip")
reader, err := gzip.NewReader(rec.Body)
assertHeaderContains(t, recorder, "Content-Encoding", "gzip")
reader, err := gzip.NewReader(recorder.Body)
if err != nil {
t.Fatalf("Failed to create gzip reader: %v", err)
}
@@ -48,14 +43,14 @@ func TestIntegration_Compression(t *testing.T) {
})
t.Run("Compression_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip, deflate")
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
request.Header.Set("Accept-Encoding", "gzip, deflate")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Header().Get("Vary") != "" {
assertHeaderContains(t, rec, "Vary", "Accept-Encoding")
if recorder.Header().Get("Vary") != "" {
assertHeaderContains(t, recorder, "Vary", "Accept-Encoding")
} else {
t.Log("Vary header may not always be present")
}
@@ -67,25 +62,19 @@ func TestIntegration_StaticFiles(t *testing.T) {
router := ctx.Router
t.Run("Robots_Txt_Served", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/robots.txt")
router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
if !strings.Contains(rec.Body.String(), "User-agent") {
if !strings.Contains(request.Body.String(), "User-agent") {
t.Error("Expected robots.txt content")
}
})
t.Run("Static_Files_Security_Headers", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/robots.txt")
router.ServeHTTP(rec, req)
if rec.Header().Get("X-Content-Type-Options") == "" {
if request.Header().Get("X-Content-Type-Options") == "" {
t.Log("Security headers may not be applied to all static files")
}
})
@@ -101,32 +90,22 @@ func TestIntegration_URLMetadata(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Fetched Title")
postBody := map[string]string{
postBody := map[string]any{
"title": "Test Post",
"url": "https://example.com/metadata-test",
"content": "Test content",
}
body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePostRequest(t, router, "/api/posts", postBody, user, nil)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
})
t.Run("URL_Metadata_Endpoint", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Endpoint Title")
req := httptest.NewRequest("GET", "/api/posts/title?url=https://example.com/test", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/api/posts/title?url=https://example.com/test")
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["title"]; !exists {

View File

@@ -1,14 +1,10 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"goyco/internal/middleware"
"goyco/internal/testutils"
)
@@ -22,33 +18,19 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Owner Post", "https://example.com/owner")
updateBody := map[string]string{
request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(updateBody)
}, otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
assertErrorResponse(t, request, http.StatusForbidden)
ctx.Router.ServeHTTP(rec, req)
request = makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}, owner, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertErrorResponse(t, rec, http.StatusForbidden)
req = httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+owner.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, owner.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("Post_Delete_Authorization", func(t *testing.T) {
@@ -58,47 +40,27 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Delete Post", "https://example.com/delete")
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, request, http.StatusForbidden)
assertErrorResponse(t, rec, http.StatusForbidden)
request = makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), owner, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req = httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+owner.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, owner.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("User_Profile_Access_Authorization", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user1", "profile_user1@example.com")
user2 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user2", "profile_user2@example.com")
firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user1", "profile_user1@example.com")
secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user2", "profile_user2@example.com")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user1.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user2.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user2.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user1.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", firstUser.User.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", firstUser.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != user1.User.ID {
t.Errorf("Expected user ID %d, got %.0f", user1.User.ID, id)
}
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != firstUser.User.ID {
t.Errorf("Expected user ID %d, got %.0f", firstUser.User.ID, id)
}
}
}
@@ -109,24 +71,13 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_user", "settings_auth@example.com")
otherUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_other", "settings_auth_other@example.com")
updateBody := map[string]string{
"email": "newemail@example.com",
}
body, _ := json.Marshal(updateBody)
request := makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "newemail@example.com"}, otherUser, nil)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response == nil {
return
}
if data, ok := response["data"].(map[string]any); ok {
if data, ok := getDataFromResponse(response); ok {
if userData, ok := data["user"].(map[string]any); ok {
if email, ok := userData["email"].(string); ok && email == "newemail@example.com" {
if id, ok := userData["id"].(float64); ok && uint(id) != otherUser.User.ID {
@@ -136,20 +87,9 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
}
}
updateBody2 := map[string]string{
"email": "anothernewemail@example.com",
}
body2, _ := json.Marshal(updateBody2)
request = makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "anothernewemail@example.com"}, user, nil)
req = httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body2))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("Vote_Authorization", func(t *testing.T) {
@@ -159,73 +99,42 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, postOwner.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
request := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
assertStatus(t, request, http.StatusOK)
ctx.Router.ServeHTTP(rec, req)
request = makePostRequestWithJSON(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"})
assertStatus(t, rec, http.StatusOK)
req = httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Protected_Endpoint_Without_Auth", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, ctx.Router, "/api/posts", map[string]any{})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Protected_Endpoint_With_Invalid_Token", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer invalid-token")
rec := httptest.NewRecorder()
request := makeRequest(t, ctx.Router, "POST", "/api/posts", []byte("{}"), map[string]string{"Content-Type": "application/json", "Authorization": "Bearer invalid-token"})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("User_List_Authorization", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "list_auth_user", "list_auth@example.com")
req := httptest.NewRequest("GET", "/api/users", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
request = makeGetRequest(t, ctx.Router, "/api/users")
req = httptest.NewRequest("GET", "/api/users", nil)
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_Token_Authorization", func(t *testing.T) {
@@ -237,18 +146,9 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
t.Fatalf("Failed to login: %v", err)
}
refreshBody := map[string]string{
"refresh_token": loginResult.RefreshToken,
}
body, _ := json.Marshal(refreshBody)
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
req := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response == nil {
return
}
@@ -260,17 +160,8 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
t.Error("Expected data field in refresh response")
}
refreshBody = map[string]string{
"refresh_token": "invalid-refresh-token",
}
body, _ = json.Marshal(refreshBody)
request = makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": "invalid-refresh-token"})
req = httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
}

View File

@@ -14,166 +14,137 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
ctx := setupPageHandlerTestContext(t)
router := ctx.Router
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) {
reqBody := url.Values{}
reqBody.Set("username", "testuser")
reqBody.Set("email", "test@example.com")
reqBody.Set("password", "SecurePass123!")
getCSRFToken := func(t *testing.T, path string, cookies ...*http.Cookie) *http.Cookie {
t.Helper()
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", rec.Code, rec.Body.String())
request := httptest.NewRequest("GET", path, nil)
for _, c := range cookies {
request.AddCookie(c)
}
if !strings.Contains(rec.Body.String(), "Invalid CSRF token") {
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
for _, cookie := range recorder.Result().Cookies() {
if cookie.Name == "csrf_token" {
return cookie
}
}
t.Fatalf("Expected CSRF cookie to be set for %s", path)
return nil
}
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) {
requestBody := url.Values{}
requestBody.Set("username", "testuser")
requestBody.Set("email", "test@example.com")
requestBody.Set("password", "SecurePass123!")
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", recorder.Code, recorder.Body.String())
}
if !strings.Contains(recorder.Body.String(), "Invalid CSRF token") {
t.Error("Expected CSRF error message")
}
})
t.Run("CSRF_Allows_Form_With_Valid_Token", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
csrfCookie := getCSRFToken(t, "/register")
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
}
}
requestBody := url.Values{}
requestBody.Set("username", "csrf_user")
requestBody.Set("email", "csrf@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", csrfCookie.Value)
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value
router.ServeHTTP(recorder, request)
reqBody := url.Values{}
reqBody.Set("username", "csrf_user")
reqBody.Set("email", "csrf@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected form submission with valid CSRF token to succeed")
}
})
t.Run("CSRF_Allows_API_Requests", func(t *testing.T) {
reqBody := map[string]string{
requestBody := map[string]string{
"username": "api_user",
"email": "api@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected API requests to bypass CSRF protection")
}
})
t.Run("CSRF_Blocks_Mismatched_Token", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
csrfCookie := getCSRFToken(t, "/register")
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
}
requestBody := url.Values{}
requestBody.Set("username", "mismatch_user")
requestBody.Set("email", "mismatch@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", "wrong-token")
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", recorder.Code, recorder.Body.String())
}
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
reqBody := url.Values{}
reqBody.Set("username", "mismatch_user")
reqBody.Set("email", "mismatch@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", "wrong-token")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", rec.Code, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "Invalid CSRF token") {
if !strings.Contains(recorder.Body.String(), "Invalid CSRF token") {
t.Error("Expected CSRF error message")
}
})
t.Run("CSRF_Allows_GET_Requests", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/register", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected GET requests to bypass CSRF protection")
}
})
t.Run("CSRF_Token_In_Header", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
csrfCookie := getCSRFToken(t, "/register")
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
}
}
requestBody := url.Values{}
requestBody.Set("username", "header_user")
requestBody.Set("email", "header@example.com")
requestBody.Set("password", "SecurePass123!")
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.Header.Set("X-CSRF-Token", csrfCookie.Value)
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value
router.ServeHTTP(recorder, request)
reqBody := url.Values{}
reqBody.Set("username", "header_user")
reqBody.Set("email", "header@example.com")
reqBody.Set("password", "SecurePass123!")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-CSRF-Token", csrfToken)
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected CSRF token in header to be accepted")
}
})
@@ -182,41 +153,24 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "csrf_form_user", "csrf_form@example.com")
getReq := httptest.NewRequest("GET", "/posts/new", nil)
getReq.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
authCookie := &http.Cookie{Name: "auth_token", Value: user.Token}
csrfCookie := getCSRFToken(t, "/posts/new", authCookie)
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
}
}
requestBody := url.Values{}
requestBody.Set("title", "CSRF Test Post")
requestBody.Set("url", "https://example.com/csrf-test")
requestBody.Set("content", "Test content")
requestBody.Set("csrf_token", csrfCookie.Value)
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
request := httptest.NewRequest("POST", "/posts", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(authCookie)
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value
router.ServeHTTP(recorder, request)
reqBody := url.Values{}
reqBody.Set("title", "CSRF Test Post")
reqBody.Set("url", "https://example.com/csrf-test")
reqBody.Set("content", "Test content")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/posts", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected post creation with valid CSRF token to succeed")
}
})

View File

@@ -19,27 +19,18 @@ func TestIntegration_DataConsistency(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "consistency_user", "consistency@example.com")
postBody := map[string]string{
request := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "Consistency Test Post",
"url": "https://example.com/consistency",
"content": "Test content",
}
body, _ := json.Marshal(postBody)
}, user, nil)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
createResponse := assertJSONResponse(t, rec, http.StatusCreated)
createResponse := assertJSONResponse(t, request, http.StatusCreated)
if createResponse == nil {
return
}
postData, ok := createResponse["data"].(map[string]any)
postData, ok := getDataFromResponse(createResponse)
if !ok {
t.Fatal("Response missing data")
}
@@ -53,16 +44,14 @@ func TestIntegration_DataConsistency(t *testing.T) {
createdURL := postData["url"]
createdContent := postData["content"]
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getResponse := assertJSONResponse(t, getRec, http.StatusOK)
getResponse := assertJSONResponse(t, getRequest, http.StatusOK)
if getResponse == nil {
return
}
getPostData, ok := getResponse["data"].(map[string]any)
getPostData, ok := getDataFromResponse(getResponse)
if !ok {
t.Fatal("Get response missing data")
}
@@ -96,32 +85,17 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Consistency Post", "https://example.com/vote-consistency")
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRequest, http.StatusOK)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
getVotesRequest := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRec, http.StatusOK)
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRequest, http.StatusOK)
if votesResponse == nil {
return
}
votesData, ok := votesResponse["data"].(map[string]any)
votesData, ok := getDataFromResponse(votesResponse)
if !ok {
t.Fatal("Votes response missing data")
}
@@ -172,32 +146,21 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Original Title", "https://example.com/original")
updateBody := map[string]string{
updateRequest := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(updateBody)
}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(updateRec, updateReq)
assertStatus(t, updateRequest, http.StatusOK)
assertStatus(t, updateRec, http.StatusOK)
getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID))
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getResponse := assertJSONResponse(t, getRec, http.StatusOK)
getResponse := assertJSONResponse(t, getRequest, http.StatusOK)
if getResponse == nil {
return
}
getPostData, ok := getResponse["data"].(map[string]any)
getPostData, ok := getDataFromResponse(getResponse)
if !ok {
t.Fatal("Get response missing data")
}
@@ -215,18 +178,12 @@ func TestIntegration_DataConsistency(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "user_posts_consistency", "user_posts_consistency@example.com")
post1 := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 1", "https://example.com/post1")
post2 := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 2", "https://example.com/post2")
firstPost := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 1", "https://example.com/post1")
secondPost := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 2", "https://example.com/post2")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response == nil {
return
}
@@ -245,26 +202,26 @@ func TestIntegration_DataConsistency(t *testing.T) {
t.Errorf("Expected at least 2 posts, got %d", len(posts))
}
foundPost1 := false
foundPost2 := false
foundFirstPost := false
foundSecondPost := false
for _, post := range posts {
if postMap, ok := post.(map[string]any); ok {
if postID, ok := postMap["id"].(float64); ok {
if uint(postID) == post1.ID {
foundPost1 = true
if uint(postID) == firstPost.ID {
foundFirstPost = true
}
if uint(postID) == post2.ID {
foundPost2 = true
if uint(postID) == secondPost.ID {
foundSecondPost = true
}
}
}
}
if !foundPost1 {
if !foundFirstPost {
t.Error("Post 1 not found in user posts")
}
if !foundPost2 {
if !foundSecondPost {
t.Error("Post 2 not found in user posts")
}
})
@@ -275,20 +232,20 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Consistency Post", "https://example.com/delete-consistency")
deleteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteReq.Header.Set("Authorization", "Bearer "+user.Token)
deleteReq = testutils.WithUserContext(deleteReq, middleware.UserIDKey, user.User.ID)
deleteReq = testutils.WithURLParams(deleteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(deleteRec, deleteReq)
deleteRequest := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteRequest.Header.Set("Authorization", "Bearer "+user.Token)
deleteRequest = testutils.WithUserContext(deleteRequest, middleware.UserIDKey, user.User.ID)
deleteRequest = testutils.WithURLParams(deleteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(deleteRecorder, deleteRequest)
assertStatus(t, deleteRec, http.StatusOK)
assertStatus(t, deleteRecorder, http.StatusOK)
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRecorder, getRequest)
assertStatus(t, getRec, http.StatusNotFound)
assertStatus(t, getRecorder, http.StatusNotFound)
})
t.Run("Vote_Removal_Consistency", func(t *testing.T) {
@@ -300,33 +257,33 @@ func TestIntegration_DataConsistency(t *testing.T) {
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
voteRequest := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteRequest.Header.Set("Content-Type", "application/json")
voteRequest.Header.Set("Authorization", "Bearer "+user.Token)
voteRequest = testutils.WithUserContext(voteRequest, middleware.UserIDKey, user.User.ID)
voteRequest = testutils.WithURLParams(voteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRecorder, voteRequest)
assertStatus(t, voteRec, http.StatusOK)
assertStatus(t, voteRecorder, http.StatusOK)
removeVoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteReq.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteReq = testutils.WithUserContext(removeVoteReq, middleware.UserIDKey, user.User.ID)
removeVoteReq = testutils.WithURLParams(removeVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(removeVoteRec, removeVoteReq)
removeVoteRequest := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteRequest.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteRequest = testutils.WithUserContext(removeVoteRequest, middleware.UserIDKey, user.User.ID)
removeVoteRequest = testutils.WithURLParams(removeVoteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(removeVoteRecorder, removeVoteRequest)
assertStatus(t, removeVoteRec, http.StatusOK)
assertStatus(t, removeVoteRecorder, http.StatusOK)
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
getVotesRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesRequest.Header.Set("Authorization", "Bearer "+user.Token)
getVotesRequest = testutils.WithUserContext(getVotesRequest, middleware.UserIDKey, user.User.ID)
getVotesRequest = testutils.WithURLParams(getVotesRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRecorder, getVotesRequest)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRecorder, http.StatusOK)
if votesResponse == nil {
return
}

View File

@@ -22,41 +22,39 @@ func TestIntegration_EdgeCases(t *testing.T) {
expiredToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDAwMDAwMDB9.expired"
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+expiredToken)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/auth/me", nil)
request.Header.Set("Authorization", "Bearer "+expiredToken)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, recorder, http.StatusUnauthorized)
})
t.Run("Concurrent_Vote_Operations", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "vote_user1", "vote1@example.com")
firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "vote_user1", "vote1@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user1.User.ID, "Concurrent Vote Post", "https://example.com/concurrent")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, firstUser.User.ID, "Concurrent Vote Post", "https://example.com/concurrent")
var wg sync.WaitGroup
errors := make(chan error, 10)
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range 5 {
wg.Go(func() {
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user1.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user1.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
errors <- fmt.Errorf("unexpected status: %d", rec.Code)
request := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+firstUser.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, firstUser.User.ID)
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("unexpected status: %d", recorder.Code)
}
}()
})
}
wg.Wait()
@@ -72,8 +70,8 @@ func TestIntegration_EdgeCases(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "large_user", "large@example.com")
largeContent := make([]byte, 10001)
for i := range largeContent {
largeContent[i] = 'a'
for idx := range largeContent {
largeContent[idx] = 'a'
}
postBody := map[string]string{
@@ -82,36 +80,36 @@ func TestIntegration_EdgeCases(t *testing.T) {
"content": string(largeContent),
}
body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, recorder, http.StatusBadRequest)
smallContent := make([]byte, 1000)
for i := range smallContent {
smallContent[i] = 'a'
for idx := range smallContent {
smallContent[idx] = 'a'
}
postBody2 := map[string]string{
secondPostBody := map[string]string{
"title": "Small Post",
"url": "https://example.com/small",
"content": string(smallContent),
}
body2, _ := json.Marshal(postBody2)
req2 := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body2))
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
secondBody, _ := json.Marshal(secondPostBody)
secondRequest := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(secondBody))
secondRequest.Header.Set("Content-Type", "application/json")
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec2, req2)
ctx.Router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec2, http.StatusCreated)
assertStatus(t, secondRecorder, http.StatusCreated)
})
t.Run("Malformed_JSON_Payloads", func(t *testing.T) {
@@ -127,15 +125,15 @@ func TestIntegration_EdgeCases(t *testing.T) {
}
for _, payload := range malformedPayloads {
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(payload))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(payload))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, recorder, http.StatusBadRequest)
}
})
@@ -148,38 +146,36 @@ func TestIntegration_EdgeCases(t *testing.T) {
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
assertStatus(t, voteRec, http.StatusOK)
voteRequest := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteRequest.Header.Set("Content-Type", "application/json")
voteRequest.Header.Set("Authorization", "Bearer "+user.Token)
voteRequest = testutils.WithUserContext(voteRequest, middleware.UserIDKey, user.User.ID)
voteRequest = testutils.WithURLParams(voteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRecorder, voteRequest)
assertStatus(t, voteRecorder, http.StatusOK)
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
}()
for range 3 {
wg.Go(func() {
request := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(recorder, request)
})
}
wg.Wait()
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
getVotesRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesRequest.Header.Set("Authorization", "Bearer "+user.Token)
getVotesRequest = testutils.WithUserContext(getVotesRequest, middleware.UserIDKey, user.User.ID)
getVotesRequest = testutils.WithURLParams(getVotesRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRecorder, getVotesRequest)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRecorder, http.StatusOK)
if votesResponse != nil {
if data, ok := votesResponse["data"].(map[string]any); ok {
if votes, ok := data["votes"].([]any); ok {

View File

@@ -1,14 +1,10 @@
package integration
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/testutils"
)
@@ -19,19 +15,14 @@ func TestIntegration_EmailService(t *testing.T) {
t.Run("Registration_Email_Sent", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "email_reg_user",
"email": "email_reg@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/register", reqBody)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
token := ctx.Suite.EmailSender.VerificationToken()
if token == "" {
@@ -52,15 +43,10 @@ func TestIntegration_EmailService(t *testing.T) {
t.Fatalf("Failed to create user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"username_or_email": "email_reset_user",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
makePostRequestWithJSON(t, router, "/api/auth/forgot-password", reqBody)
token := ctx.Suite.EmailSender.PasswordResetToken()
if token == "" {
@@ -72,17 +58,9 @@ func TestIntegration_EmailService(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_del_user", "email_del@example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeDeleteRequest(t, router, "/api/auth/account", user, nil)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
token := ctx.Suite.EmailSender.DeletionToken()
if token == "" {
@@ -94,17 +72,10 @@ func TestIntegration_EmailService(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_change_user", "email_change@example.com")
reqBody := map[string]string{
reqBody := map[string]any{
"email": "newemail@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
makePutRequest(t, router, "/api/auth/email", reqBody, user, nil)
token := ctx.Suite.EmailSender.VerificationToken()
if token == "" {
@@ -115,17 +86,12 @@ func TestIntegration_EmailService(t *testing.T) {
t.Run("Email_Template_Content", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "template_user",
"email": "template@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
makePostRequestWithJSON(t, router, "/api/auth/register", reqBody)
token := ctx.Suite.EmailSender.VerificationToken()
if token == "" {

View File

@@ -1,7 +1,6 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -10,7 +9,7 @@ import (
"strings"
"testing"
"goyco/internal/middleware"
"goyco/internal/database"
"goyco/internal/testutils"
)
@@ -20,46 +19,34 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Run("Complete_Registration_To_Post_Creation_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
registerBody := map[string]string{
registerRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "journey_user",
"email": "journey@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(registerBody)
registerReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
registerReq.Header.Set("Content-Type", "application/json")
registerRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(registerRec, registerReq)
})
assertStatus(t, registerRec, http.StatusCreated)
assertStatus(t, registerRequest, http.StatusCreated)
verificationToken := ctx.Suite.EmailSender.VerificationToken()
if verificationToken == "" {
t.Fatal("Verification token not sent")
}
confirmReq := httptest.NewRequest("GET", "/api/auth/confirm?token="+url.QueryEscape(verificationToken), nil)
confirmRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(confirmRec, confirmReq)
confirmRequest := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(verificationToken))
assertStatus(t, confirmRec, http.StatusOK)
assertStatus(t, confirmRequest, http.StatusOK)
loginBody := map[string]string{
loginRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/login", map[string]any{
"username": "journey_user",
"password": "SecurePass123!",
}
loginBodyBytes, _ := json.Marshal(loginBody)
loginReq := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBuffer(loginBodyBytes))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(loginRec, loginReq)
})
loginResponse := assertJSONResponse(t, loginRec, http.StatusOK)
loginResponse := assertJSONResponse(t, loginRequest, http.StatusOK)
if loginResponse == nil {
return
}
data, ok := loginResponse["data"].(map[string]any)
data, ok := getDataFromResponse(loginResponse)
if !ok {
t.Fatal("Login response missing data")
}
@@ -67,8 +54,8 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
var token string
if accessToken, ok := data["access_token"].(string); ok && accessToken != "" {
token = accessToken
} else if tokenVal, ok := data["token"].(string); ok && tokenVal != "" {
token = tokenVal
} else if tokenValue, ok := data["token"].(string); ok && tokenValue != "" {
token = tokenValue
} else {
t.Fatal("Login response missing access_token or token")
}
@@ -90,25 +77,19 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatalf("Login response missing user.id. Data: %+v", data)
}
postBody := map[string]string{
postBodyBytes, _ := json.Marshal(map[string]any{
"title": "Journey Test Post",
"url": "https://example.com/journey",
"content": "Test content",
}
postBodyBytes, _ := json.Marshal(postBody)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBodyBytes))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, uint(userID))
postRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(postRec, postReq)
})
postRequest := makeAuthenticatedRequest(t, ctx.Router, "POST", "/api/posts", postBodyBytes, &authenticatedUser{User: &database.User{ID: uint(userID)}, Token: token}, nil)
postResponse := assertJSONResponse(t, postRec, http.StatusCreated)
postResponse := assertJSONResponse(t, postRequest, http.StatusCreated)
if postResponse == nil {
return
}
postData, ok := postResponse["data"].(map[string]any)
postData, ok := getDataFromResponse(postResponse)
if !ok {
t.Fatal("Post response missing data")
}
@@ -118,56 +99,39 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Post response missing id")
}
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil)
getPostRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getPostRec, getPostReq)
getPostRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
assertStatus(t, getPostRec, http.StatusOK)
assertStatus(t, getPostRequest, http.StatusOK)
})
t.Run("Complete_Password_Reset_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "reset_journey_user", "reset_journey@example.com")
resetBody := map[string]string{
resetRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/forgot-password", map[string]any{
"username_or_email": "reset_journey@example.com",
}
resetBodyBytes, _ := json.Marshal(resetBody)
resetReq := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(resetBodyBytes))
resetReq.Header.Set("Content-Type", "application/json")
resetRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(resetRec, resetReq)
})
assertStatus(t, resetRec, http.StatusOK)
assertStatus(t, resetRequest, http.StatusOK)
resetToken := ctx.Suite.EmailSender.GetLastPasswordResetToken()
if resetToken == "" {
t.Fatal("Password reset token not sent")
}
newPasswordBody := map[string]string{
newPasswordRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/reset-password", map[string]any{
"token": resetToken,
"new_password": "NewSecurePass123!",
}
newPasswordBodyBytes, _ := json.Marshal(newPasswordBody)
newPasswordReq := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(newPasswordBodyBytes))
newPasswordReq.Header.Set("Content-Type", "application/json")
newPasswordRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(newPasswordRec, newPasswordReq)
})
assertStatus(t, newPasswordRec, http.StatusOK)
assertStatus(t, newPasswordRequest, http.StatusOK)
loginBody := map[string]string{
loginRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/login", map[string]any{
"username": "reset_journey_user",
"password": "NewSecurePass123!",
}
loginBodyBytes, _ := json.Marshal(loginBody)
loginReq := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBuffer(loginBodyBytes))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(loginRec, loginReq)
})
assertStatus(t, loginRec, http.StatusOK)
assertStatus(t, loginRequest, http.StatusOK)
})
t.Run("Complete_Vote_And_Unvote_Journey", func(t *testing.T) {
@@ -176,40 +140,21 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Journey Post", "https://example.com/vote-journey")
voteBody := map[string]string{"type": "up"}
voteBodyBytes, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRequest, http.StatusOK)
assertStatus(t, voteRec, http.StatusOK)
getVotesRequest := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRequest, http.StatusOK)
if votesResponse == nil {
return
}
if data, ok := votesResponse["data"].(map[string]any); ok {
if data, ok := getDataFromResponse(votesResponse); ok {
if votes, ok := data["votes"].([]any); ok && len(votes) > 0 {
unvoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
unvoteReq.Header.Set("Authorization", "Bearer "+user.Token)
unvoteReq = testutils.WithUserContext(unvoteReq, middleware.UserIDKey, user.User.ID)
unvoteReq = testutils.WithURLParams(unvoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
unvoteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(unvoteRec, unvoteReq)
unvoteRequest := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, unvoteRec, http.StatusOK)
assertStatus(t, unvoteRequest, http.StatusOK)
}
}
})
@@ -221,31 +166,31 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
csrfToken := getCSRFToken(t, pageRouter, "/register")
reqBody := url.Values{}
reqBody.Set("username", "page_journey_user")
reqBody.Set("email", "page_journey@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("password_confirm", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("username", "page_journey_user")
requestBody.Set("email", "page_journey@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("password_confirm", "SecurePass123!")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
pageRouter.ServeHTTP(rec, req)
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
pageRouter.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
verificationToken := pageCtx.Suite.EmailSender.VerificationToken()
if verificationToken == "" {
t.Fatal("Verification token not sent")
}
confirmReq := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(verificationToken), nil)
confirmRec := httptest.NewRecorder()
pageRouter.ServeHTTP(confirmRec, confirmReq)
confirmRequest := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(verificationToken), nil)
confirmRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(confirmRecorder, confirmRequest)
assertStatusRange(t, confirmRec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, confirmRecorder, http.StatusOK, http.StatusSeeOther)
loginCSRFToken := getCSRFToken(t, pageRouter, "/login")
@@ -254,15 +199,15 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
loginBody.Set("password", "SecurePass123!")
loginBody.Set("csrf_token", loginCSRFToken)
loginReq := httptest.NewRequest("POST", "/login", strings.NewReader(loginBody.Encode()))
loginReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
loginReq.AddCookie(&http.Cookie{Name: "csrf_token", Value: loginCSRFToken})
loginRec := httptest.NewRecorder()
pageRouter.ServeHTTP(loginRec, loginReq)
loginRequest := httptest.NewRequest("POST", "/login", strings.NewReader(loginBody.Encode()))
loginRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded")
loginRequest.AddCookie(&http.Cookie{Name: "csrf_token", Value: loginCSRFToken})
loginRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(loginRecorder, loginRequest)
assertStatus(t, loginRec, http.StatusSeeOther)
assertStatus(t, loginRecorder, http.StatusSeeOther)
loginCookies := loginRec.Result().Cookies()
loginCookies := loginRecorder.Result().Cookies()
var authToken string
for _, cookie := range loginCookies {
if cookie.Name == "auth_token" {
@@ -275,37 +220,31 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Auth token not set after login")
}
homeReq := httptest.NewRequest("GET", "/", nil)
homeReq.AddCookie(&http.Cookie{Name: "auth_token", Value: authToken})
homeRec := httptest.NewRecorder()
pageRouter.ServeHTTP(homeRec, homeReq)
homeRequest := httptest.NewRequest("GET", "/", nil)
homeRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: authToken})
homeRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(homeRecorder, homeRequest)
assertStatus(t, homeRec, http.StatusOK)
assertStatus(t, homeRecorder, http.StatusOK)
})
t.Run("Complete_Post_Creation_And_Update_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "post_update_journey_user", "post_update_journey@example.com")
postBody := map[string]string{
postBodyBytes, _ := json.Marshal(map[string]any{
"title": "Original Title",
"url": "https://example.com/original",
"content": "Original content",
}
postBodyBytes, _ := json.Marshal(postBody)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBodyBytes))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+user.Token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, user.User.ID)
postRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(postRec, postReq)
})
postRequest := makeAuthenticatedRequest(t, ctx.Router, "POST", "/api/posts", postBodyBytes, user, nil)
postResponse := assertJSONResponse(t, postRec, http.StatusCreated)
postResponse := assertJSONResponse(t, postRequest, http.StatusCreated)
if postResponse == nil {
return
}
postData, ok := postResponse["data"].(map[string]any)
postData, ok := getDataFromResponse(postResponse)
if !ok {
t.Fatal("Post response missing data")
}
@@ -315,34 +254,25 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Post response missing id")
}
updateBody := map[string]string{
updateBodyBytes, _ := json.Marshal(map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
updateBodyBytes, _ := json.Marshal(updateBody)
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%.0f", postID), bytes.NewBuffer(updateBodyBytes))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%.0f", postID)})
updateRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(updateRec, updateReq)
})
updateRequest := makeAuthenticatedRequest(t, ctx.Router, "PUT", fmt.Sprintf("/api/posts/%.0f", postID), updateBodyBytes, user, map[string]string{"id": fmt.Sprintf("%.0f", postID)})
updateResponse := assertJSONResponse(t, updateRec, http.StatusOK)
updateResponse := assertJSONResponse(t, updateRequest, http.StatusOK)
if updateResponse == nil {
return
}
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil)
getPostRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getPostRec, getPostReq)
getPostRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getPostResponse := assertJSONResponse(t, getPostRec, http.StatusOK)
getPostResponse := assertJSONResponse(t, getPostRequest, http.StatusOK)
if getPostResponse == nil {
return
}
if data, ok := getPostResponse["data"].(map[string]any); ok {
if data, ok := getDataFromResponse(getPostResponse); ok {
if post, ok := data["post"].(map[string]any); ok {
if title, ok := post["title"].(string); ok && title != "Updated Title" {
t.Errorf("Post title not updated: expected 'Updated Title', got '%s'", title)

View File

@@ -2,7 +2,6 @@ package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
@@ -19,58 +18,46 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "json_error_user", "json_error@example.com")
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("invalid json{")))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("invalid json{")))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, recorder, http.StatusBadRequest)
})
t.Run("Validation_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "",
"email": "invalid-email",
"password": "weak",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Database_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "db_error_user", "db_error@example.com")
reqBody := map[string]string{
reqBody := map[string]any{
"title": "Test Post",
"url": "https://example.com/test",
"content": "Test content",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePostRequest(t, ctx.Router, "/api/posts", reqBody, user, nil)
ctx.Router.ServeHTTP(rec, req)
if rec.Code == http.StatusInternalServerError {
assertErrorResponse(t, rec, http.StatusInternalServerError)
if request.Code == http.StatusInternalServerError {
assertErrorResponse(t, request, http.StatusInternalServerError)
} else {
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
}
})
@@ -78,34 +65,23 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "notfound_error_user", "notfound_error@example.com")
req := httptest.NewRequest("GET", "/api/posts/999999", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": "999999"})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/posts/999999", user, map[string]string{"id": "999999"})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusNotFound)
assertErrorResponse(t, request, http.StatusNotFound)
})
t.Run("Unauthorized_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"title": "Test Post",
"url": "https://example.com/test",
"content": "Test content",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, ctx.Router, "/api/posts", reqBody)
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Forbidden_Error_Propagation", func(t *testing.T) {
@@ -115,78 +91,58 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Forbidden Post", "https://example.com/forbidden")
updateBody := map[string]string{
updateBody := map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), updateBody, otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusForbidden)
assertErrorResponse(t, request, http.StatusForbidden)
})
t.Run("Service_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "existing_user",
"email": "existing@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
req = httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
request = makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
assertStatusRange(t, rec, http.StatusBadRequest, http.StatusConflict)
assertErrorResponse(t, rec, rec.Code)
assertStatusRange(t, request, http.StatusBadRequest, http.StatusConflict)
assertErrorResponse(t, request, request.Code)
})
t.Run("Middleware_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer expired.invalid.token")
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer expired.invalid.token")
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, recorder, http.StatusUnauthorized)
})
t.Run("Handler_Error_Response_Format", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("GET", "/api/nonexistent", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, ctx.Router, "/api/nonexistent")
ctx.Router.ServeHTTP(rec, req)
if rec.Code == http.StatusNotFound {
if rec.Header().Get("Content-Type") == "application/json" {
assertErrorResponse(t, rec, http.StatusNotFound)
} else {
if rec.Body.Len() == 0 {
t.Error("Expected error response body")
}
if request.Code == http.StatusNotFound {
if request.Header().Get("Content-Type") == "application/json" {
assertErrorResponse(t, request, http.StatusNotFound)
} else if request.Body.Len() == 0 {
t.Error("Expected error response body")
}
}
})

View File

@@ -11,55 +11,35 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"goyco/internal/database"
"goyco/internal/handlers"
"goyco/internal/middleware"
"goyco/internal/repositories"
"goyco/internal/services"
"goyco/internal/testutils"
"github.com/golang-jwt/jwt/v5"
)
func TestIntegration_Handlers(t *testing.T) {
suite := testutils.NewServiceSuite(t)
authService, err := services.NewAuthFacadeForTest(testutils.AppTestConfig, suite.UserRepo, suite.PostRepo, suite.DeletionRepo, suite.RefreshTokenRepo, suite.EmailSender)
if err != nil {
t.Fatalf("Failed to create auth service: %v", err)
}
voteService := services.NewVoteService(suite.VoteRepo, suite.PostRepo, suite.DB)
emailSender := suite.EmailSender
userRepo := suite.UserRepo
postRepo := suite.PostRepo
titleFetcher := suite.TitleFetcher
authHandler := handlers.NewAuthHandler(authService, userRepo)
postHandler := handlers.NewPostHandler(postRepo, titleFetcher, voteService)
voteHandler := handlers.NewVoteHandler(voteService)
userHandler := handlers.NewUserHandler(userRepo, authService)
ctx := setupTestContext(t)
authService := ctx.AuthService
emailSender := ctx.Suite.EmailSender
userRepo := ctx.Suite.UserRepo
postRepo := ctx.Suite.PostRepo
t.Run("Auth_Handler_Complete_Workflow", func(t *testing.T) {
emailSender.Reset()
registerData := map[string]string{
registerResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "handler_user",
"email": "handler@example.com",
"password": "SecurePass123!",
}
registerBody, _ := json.Marshal(registerData)
registerReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(registerBody))
registerReq.Header.Set("Content-Type", "application/json")
registerResp := httptest.NewRecorder()
authHandler.Register(registerResp, registerReq)
if registerResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", registerResp.Code)
})
if registerResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", registerResponse.Code)
}
var registerPayload map[string]any
if err := json.Unmarshal(registerResp.Body.Bytes(), &registerPayload); err != nil {
t.Fatalf("Failed to decode register response: %v", err)
}
registerPayload := assertJSONResponse(t, registerResponse, http.StatusCreated)
if success, _ := registerPayload["success"].(bool); !success {
t.Fatalf("Expected register response success, got %v", registerPayload)
}
@@ -78,11 +58,9 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to update user with mock token: %v", err)
}
confirmReq := httptest.NewRequest(http.MethodGet, "/api/auth/confirm?token="+url.QueryEscape(mockToken), nil)
confirmResp := httptest.NewRecorder()
authHandler.ConfirmEmail(confirmResp, confirmReq)
if confirmResp.Code != http.StatusOK {
t.Fatalf("Expected 200 when confirming email via handler, got %d", confirmResp.Code)
confirmResponse := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(mockToken))
if confirmResponse.Code != http.StatusOK {
t.Fatalf("Expected 200 when confirming email via handler, got %d", confirmResponse.Code)
}
loginSeed := createAuthenticatedUser(t, authService, userRepo, "auth_handler_login", "auth_handler_login@example.com")
@@ -92,92 +70,60 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Service login failed for seeded user: %v", err)
}
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+loginAuth.AccessToken)
meReq = testutils.WithUserContext(meReq, middleware.UserIDKey, loginSeed.User.ID)
meResp := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
if meResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", meResp.Code)
meResponse := makeAuthenticatedGetRequest(t, ctx.Router, "/api/auth/me", &authenticatedUser{User: loginSeed.User, Token: loginAuth.AccessToken}, nil)
if meResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", meResponse.Code)
}
})
t.Run("Auth_Handler_Security_Validation", func(t *testing.T) {
emailSender.Reset()
weakData := map[string]string{
weakResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "weak_user",
"email": "weak@example.com",
"password": "123",
}
weakBody, _ := json.Marshal(weakData)
weakReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(weakBody))
weakReq.Header.Set("Content-Type", "application/json")
weakResp := httptest.NewRecorder()
authHandler.Register(weakResp, weakReq)
if weakResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for weak password, got %d", weakResp.Code)
})
if weakResponse.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for weak password, got %d", weakResponse.Code)
}
var weakErrorResp map[string]any
if err := json.Unmarshal(weakResp.Body.Bytes(), &weakErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := weakErrorResp["success"].(bool); success {
weakErrorResponse := assertJSONResponse(t, weakResponse, http.StatusBadRequest)
if success, _ := weakErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := weakErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := weakErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
invalidData := map[string]string{
invalidResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "invalid_user",
"email": "not-an-email",
"password": "SecurePass123!",
}
invalidBody, _ := json.Marshal(invalidData)
invalidReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(invalidBody))
invalidReq.Header.Set("Content-Type", "application/json")
invalidResp := httptest.NewRecorder()
authHandler.Register(invalidResp, invalidReq)
if invalidResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid email, got %d", invalidResp.Code)
})
if invalidResponse.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid email, got %d", invalidResponse.Code)
}
var invalidEmailErrorResp map[string]any
if err := json.Unmarshal(invalidResp.Body.Bytes(), &invalidEmailErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := invalidEmailErrorResp["success"].(bool); success {
invalidEmailErrorResponse := assertJSONResponse(t, invalidResponse, http.StatusBadRequest)
if success, _ := invalidEmailErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := invalidEmailErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := invalidEmailErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
incompleteData := map[string]string{
incompleteResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "incomplete_user",
}
incompleteBody, _ := json.Marshal(incompleteData)
incompleteReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(incompleteBody))
incompleteReq.Header.Set("Content-Type", "application/json")
incompleteResp := httptest.NewRecorder()
authHandler.Register(incompleteResp, incompleteReq)
if incompleteResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for missing fields, got %d", incompleteResp.Code)
})
if incompleteResponse.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for missing fields, got %d", incompleteResponse.Code)
}
var incompleteErrorResp map[string]any
if err := json.Unmarshal(incompleteResp.Body.Bytes(), &incompleteErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := incompleteErrorResp["success"].(bool); success {
incompleteErrorResponse := assertJSONResponse(t, incompleteResponse, http.StatusBadRequest)
if success, _ := incompleteErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := incompleteErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := incompleteErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
})
@@ -186,28 +132,17 @@ func TestIntegration_Handlers(t *testing.T) {
emailSender.Reset()
user := createAuthenticatedUser(t, authService, userRepo, "post_user", "post@example.com")
postData := map[string]string{
postResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "Handler Test Post",
"url": "https://example.com/handler-test",
"content": "This is a handler test post",
}
postBody, _ := json.Marshal(postData)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBody))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+user.Token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, user.User.ID)
postResp := httptest.NewRecorder()
postHandler.CreatePost(postResp, postReq)
if postResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", postResp.Code)
}, user, nil)
if postResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", postResponse.Code)
}
var postResult map[string]any
if err := json.Unmarshal(postResp.Body.Bytes(), &postResult); err != nil {
t.Fatalf("Failed to decode post response: %v", err)
}
postDetails, ok := postResult["data"].(map[string]any)
postResult := assertJSONResponse(t, postResponse, http.StatusCreated)
postDetails, ok := getDataFromResponse(postResult)
if !ok {
t.Fatalf("Expected data object in post response, got %v", postResult)
}
@@ -216,87 +151,49 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatal("Expected post ID in response")
}
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", int(postID)), nil)
getReq = testutils.WithURLParams(getReq, map[string]string{"id": fmt.Sprintf("%d", int(postID))})
getResp := httptest.NewRecorder()
postHandler.GetPost(getResp, getReq)
if getResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getResp.Code)
getResponse := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", int(postID)))
if getResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getResponse.Code)
}
postsReq := httptest.NewRequest("GET", "/api/posts", nil)
postsResp := httptest.NewRecorder()
postHandler.GetPosts(postsResp, postsReq)
if postsResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", postsResp.Code)
postsResponse := makeGetRequest(t, ctx.Router, "/api/posts")
if postsResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", postsResponse.Code)
}
searchReq := httptest.NewRequest("GET", "/api/posts/search?q=handler", nil)
searchResp := httptest.NewRecorder()
postHandler.SearchPosts(searchResp, searchReq)
if searchResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", searchResp.Code)
searchResponse := makeGetRequest(t, ctx.Router, "/api/posts/search?q=handler")
if searchResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", searchResponse.Code)
}
})
t.Run("Post_Handler_Security_Validation", func(t *testing.T) {
emailSender.Reset()
postData := map[string]string{
postResponse := makePostRequestWithJSON(t, ctx.Router, "/api/posts", map[string]any{
"title": "Unauthorized Post",
"url": "https://example.com/unauthorized",
"content": "This should fail",
}
postBody, _ := json.Marshal(postData)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBody))
postReq.Header.Set("Content-Type", "application/json")
postResp := httptest.NewRecorder()
postHandler.CreatePost(postResp, postReq)
if postResp.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401 for unauthenticated post creation, got %d", postResp.Code)
}
var authErrorResp map[string]any
if err := json.Unmarshal(postResp.Body.Bytes(), &authErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := authErrorResp["success"].(bool); success {
})
authErrorResponse := assertJSONResponse(t, postResponse, http.StatusUnauthorized)
if success, _ := authErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := authErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := authErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain authentication error message")
}
user := createAuthenticatedUser(t, authService, userRepo, "security_user", "security@example.com")
invalidData := map[string]string{
invalidResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "",
"url": "not-a-url",
"content": "Invalid post",
}
invalidBody, _ := json.Marshal(invalidData)
invalidReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(invalidBody))
invalidReq.Header.Set("Content-Type", "application/json")
invalidReq.Header.Set("Authorization", "Bearer "+user.Token)
invalidReq = testutils.WithUserContext(invalidReq, middleware.UserIDKey, user.User.ID)
invalidResp := httptest.NewRecorder()
postHandler.CreatePost(invalidResp, invalidReq)
if invalidResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid post data, got %d", invalidResp.Code)
}
var postValidationErrorResp map[string]any
if err := json.Unmarshal(invalidResp.Body.Bytes(), &postValidationErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := postValidationErrorResp["success"].(bool); success {
}, user, nil)
postValidationErrorResponse := assertJSONResponse(t, invalidResponse, http.StatusBadRequest)
if success, _ := postValidationErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := postValidationErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := postValidationErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
})
@@ -306,161 +203,100 @@ func TestIntegration_Handlers(t *testing.T) {
user := createAuthenticatedUser(t, authService, userRepo, "vote_handler_user", "vote_handler@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user.User.ID, "Vote Handler Test Post", "https://example.com/vote-handler")
voteData := map[string]string{
"type": "up",
}
voteBody, _ := json.Marshal(voteData)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBody))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteResp := httptest.NewRecorder()
voteResponse := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteResponse, http.StatusOK)
voteHandler.CastVote(voteResp, voteReq)
if voteResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", voteResp.Code)
}
getVoteResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, getVoteResponse, http.StatusOK)
getVoteReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
getVoteReq.Header.Set("Authorization", "Bearer "+user.Token)
getVoteReq = testutils.WithUserContext(getVoteReq, middleware.UserIDKey, user.User.ID)
getVoteReq = testutils.WithURLParams(getVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVoteResp := httptest.NewRecorder()
getPostVotesResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, getPostVotesResponse, http.StatusOK)
voteHandler.GetUserVote(getVoteResp, getVoteReq)
if getVoteResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getVoteResp.Code)
}
getPostVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getPostVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getPostVotesReq = testutils.WithUserContext(getPostVotesReq, middleware.UserIDKey, user.User.ID)
getPostVotesReq = testutils.WithURLParams(getPostVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostVotesResp := httptest.NewRecorder()
voteHandler.GetPostVotes(getPostVotesResp, getPostVotesReq)
if getPostVotesResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getPostVotesResp.Code)
}
removeVoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteReq.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteReq = testutils.WithUserContext(removeVoteReq, middleware.UserIDKey, user.User.ID)
removeVoteReq = testutils.WithURLParams(removeVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteResp := httptest.NewRecorder()
voteHandler.RemoveVote(removeVoteResp, removeVoteReq)
if removeVoteResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", removeVoteResp.Code)
}
removeVoteResponse := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, removeVoteResponse, http.StatusOK)
})
t.Run("User_Handler_Complete_Workflow", func(t *testing.T) {
emailSender.Reset()
user := createAuthenticatedUser(t, authService, userRepo, "user_handler_user", "user_handler@example.com")
usersReq := httptest.NewRequest("GET", "/api/users", nil)
usersReq.Header.Set("Authorization", "Bearer "+user.Token)
usersReq = testutils.WithUserContext(usersReq, middleware.UserIDKey, user.User.ID)
usersResp := httptest.NewRecorder()
usersResponse := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
assertStatus(t, usersResponse, http.StatusOK)
userHandler.GetUsers(usersResp, usersReq)
if usersResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", usersResp.Code)
}
getUserResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
assertStatus(t, getUserResponse, http.StatusOK)
getUserReq := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user.User.ID), nil)
getUserReq.Header.Set("Authorization", "Bearer "+user.Token)
getUserReq = testutils.WithUserContext(getUserReq, middleware.UserIDKey, user.User.ID)
getUserReq = testutils.WithURLParams(getUserReq, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserResp := httptest.NewRecorder()
userHandler.GetUser(getUserResp, getUserReq)
if getUserResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getUserResp.Code)
}
getUserPostsReq := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil)
getUserPostsReq.Header.Set("Authorization", "Bearer "+user.Token)
getUserPostsReq = testutils.WithUserContext(getUserPostsReq, middleware.UserIDKey, user.User.ID)
getUserPostsReq = testutils.WithURLParams(getUserPostsReq, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserPostsResp := httptest.NewRecorder()
userHandler.GetUserPosts(getUserPostsResp, getUserPostsReq)
if getUserPostsResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getUserPostsResp.Code)
}
getUserPostsResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
assertStatus(t, getUserPostsResponse, http.StatusOK)
})
t.Run("Error_Handling_Invalid_Requests", func(t *testing.T) {
invalidJSONReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer([]byte("invalid json")))
invalidJSONReq.Header.Set("Content-Type", "application/json")
invalidJSONResp := httptest.NewRecorder()
middleware.StopAllRateLimiters()
ctx.Suite.EmailSender.Reset()
authHandler.Register(invalidJSONResp, invalidJSONReq)
if invalidJSONResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid JSON, got %d", invalidJSONResp.Code)
}
var jsonErrorResp map[string]any
if err := json.Unmarshal(invalidJSONResp.Body.Bytes(), &jsonErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := jsonErrorResp["success"].(bool); success {
invalidJSONResponse := makeRequest(t, ctx.Router, "POST", "/api/auth/register", []byte("invalid json"), map[string]string{"Content-Type": "application/json"})
jsonErrorResponse := assertJSONResponse(t, invalidJSONResponse, http.StatusBadRequest)
if success, _ := jsonErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := jsonErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := jsonErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain JSON parsing error message")
}
missingCTData := map[string]string{
"username": "missing_ct_user",
"email": "missing_ct@example.com",
"username": uniqueTestUsername(t, "missing_ct"),
"email": uniqueTestEmail(t, "missing_ct"),
"password": "SecurePass123!",
}
missingCTBody, _ := json.Marshal(missingCTData)
missingCTReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(missingCTBody))
missingCTResp := httptest.NewRecorder()
missingCTRequest := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(missingCTBody))
missingCTResponse := httptest.NewRecorder()
authHandler.Register(missingCTResp, missingCTReq)
if missingCTResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", missingCTResp.Code)
ctx.Router.ServeHTTP(missingCTResponse, missingCTRequest)
if missingCTResponse.Code == http.StatusTooManyRequests {
var rateLimitResponse map[string]any
if err := json.Unmarshal(missingCTResponse.Body.Bytes(), &rateLimitResponse); err != nil {
t.Errorf("Rate limited but response is not valid JSON: %v", err)
} else {
t.Logf("Rate limit hit (expected in full test suite run), but request was processed correctly (not rejected as invalid JSON)")
}
} else if missingCTResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", missingCTResponse.Code)
}
invalidEndpointReq := httptest.NewRequest("GET", "/api/invalid/endpoint", nil)
invalidEndpointResp := httptest.NewRecorder()
invalidEndpointRequest := httptest.NewRequest("GET", "/api/invalid/endpoint", nil)
invalidEndpointResponse := httptest.NewRecorder()
authHandler.Me(invalidEndpointResp, invalidEndpointReq)
if invalidEndpointResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(invalidEndpointResponse, invalidEndpointRequest)
if invalidEndpointResponse.Code == http.StatusOK {
t.Error("Expected error for invalid endpoint")
}
})
t.Run("Security_Authentication_Bypass", func(t *testing.T) {
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meResp := httptest.NewRecorder()
meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
if meResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(meResponse, meRequest)
if meResponse.Code == http.StatusOK {
t.Error("Expected error for unauthenticated request")
}
invalidTokenReq := httptest.NewRequest("GET", "/api/auth/me", nil)
invalidTokenReq.Header.Set("Authorization", "Bearer invalid-token")
invalidTokenResp := httptest.NewRecorder()
invalidTokenRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
invalidTokenRequest.Header.Set("Authorization", "Bearer invalid-token")
invalidTokenResponse := httptest.NewRecorder()
authHandler.Me(invalidTokenResp, invalidTokenReq)
if invalidTokenResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(invalidTokenResponse, invalidTokenRequest)
if invalidTokenResponse.Code == http.StatusOK {
t.Error("Expected error for invalid token")
}
malformedTokenReq := httptest.NewRequest("GET", "/api/auth/me", nil)
malformedTokenReq.Header.Set("Authorization", "InvalidFormat token")
malformedTokenResp := httptest.NewRecorder()
malformedTokenRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
malformedTokenRequest.Header.Set("Authorization", "InvalidFormat token")
malformedTokenResponse := httptest.NewRecorder()
authHandler.Me(malformedTokenResp, malformedTokenReq)
if malformedTokenResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(malformedTokenResponse, malformedTokenRequest)
if malformedTokenResponse.Code == http.StatusOK {
t.Error("Expected error for malformed token")
}
})
@@ -468,32 +304,21 @@ func TestIntegration_Handlers(t *testing.T) {
t.Run("Security_Input_Sanitization", func(t *testing.T) {
user := createAuthenticatedUser(t, authService, userRepo, "xss_user", "xss@example.com")
xssData := map[string]string{
xssResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "<script>alert('xss')</script>",
"url": "https://example.com/xss",
"content": "XSS test content",
}
xssBody, _ := json.Marshal(xssData)
xssReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(xssBody))
xssReq.Header.Set("Content-Type", "application/json")
xssReq.Header.Set("Authorization", "Bearer "+user.Token)
xssReq = testutils.WithUserContext(xssReq, middleware.UserIDKey, user.User.ID)
xssResp := httptest.NewRecorder()
postHandler.CreatePost(xssResp, xssReq)
if xssResp.Code != http.StatusCreated {
t.Errorf("Expected status 201 for XSS sanitization, got %d", xssResp.Code)
"content": "<script>alert('xss')</script>",
}, user, nil)
if xssResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201 for XSS sanitization, got %d", xssResponse.Code)
}
var xssResult map[string]any
if err := json.Unmarshal(xssResp.Body.Bytes(), &xssResult); err != nil {
t.Fatalf("Failed to decode XSS response: %v", err)
}
xssResult := assertJSONResponse(t, xssResponse, http.StatusCreated)
if success, _ := xssResult["success"].(bool); !success {
t.Error("Expected XSS response to have success=true")
}
data, ok := xssResult["data"].(map[string]any)
data, ok := getDataFromResponse(xssResult)
if !ok {
t.Fatalf("Expected data object in XSS response, got %T", xssResult["data"])
}
@@ -522,32 +347,21 @@ func TestIntegration_Handlers(t *testing.T) {
t.Errorf("Expected script tags to be HTML-escaped in content, got: %s", content)
}
sqlData := map[string]string{
sqlResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "'; DROP TABLE posts; --",
"url": "https://example.com/sql",
"content": "SQL injection test",
}
sqlBody, _ := json.Marshal(sqlData)
sqlReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(sqlBody))
sqlReq.Header.Set("Content-Type", "application/json")
sqlReq.Header.Set("Authorization", "Bearer "+user.Token)
sqlReq = testutils.WithUserContext(sqlReq, middleware.UserIDKey, user.User.ID)
sqlResp := httptest.NewRecorder()
postHandler.CreatePost(sqlResp, sqlReq)
if sqlResp.Code != http.StatusCreated {
t.Errorf("Expected status 201 for SQL injection sanitization, got %d", sqlResp.Code)
}, user, nil)
if sqlResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201 for SQL injection sanitization, got %d", sqlResponse.Code)
}
var sqlResult map[string]any
if err := json.Unmarshal(sqlResp.Body.Bytes(), &sqlResult); err != nil {
t.Fatalf("Failed to decode SQL response: %v", err)
}
sqlResult := assertJSONResponse(t, sqlResponse, http.StatusCreated)
if success, _ := sqlResult["success"].(bool); !success {
t.Error("Expected SQL response to have success=true")
}
sqlResponseData, ok := sqlResult["data"].(map[string]any)
sqlResponseData, ok := getDataFromResponse(sqlResult)
if !ok {
t.Fatalf("Expected data object in SQL response, got %T", sqlResult["data"])
}
@@ -579,63 +393,31 @@ func TestIntegration_Handlers(t *testing.T) {
t.Run("Authorization_User_Access_Control", func(t *testing.T) {
emailSender.Reset()
user1 := createAuthenticatedUser(t, authService, userRepo, "auth_user1", "auth1@example.com")
user2 := createAuthenticatedUser(t, authService, userRepo, "auth_user2", "auth2@example.com")
firstUser := createAuthenticatedUser(t, authService, userRepo, "auth_user1", "auth1@example.com")
secondUser := createAuthenticatedUser(t, authService, userRepo, "auth_user2", "auth2@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user1.User.ID, "Private Post", "https://example.com/private")
post := testutils.CreatePostWithRepo(t, postRepo, firstUser.User.ID, "Private Post", "https://example.com/private")
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getPostReq.Header.Set("Authorization", "Bearer "+user2.Token)
getPostReq = testutils.WithUserContext(getPostReq, middleware.UserIDKey, user2.User.ID)
getPostReq = testutils.WithURLParams(getPostReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostResp := httptest.NewRecorder()
getPostResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
testutils.AssertHTTPStatus(t, getPostResponse, http.StatusOK)
postHandler.GetPost(getPostResp, getPostReq)
testutils.AssertHTTPStatus(t, getPostResp, http.StatusOK)
updateResponse := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{"title": "Updated Title"}, secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
testutils.AssertHTTPStatus(t, updateResponse, http.StatusForbidden)
updateData := map[string]string{
"title": "Updated Title",
}
updateBody, _ := json.Marshal(updateData)
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(updateBody))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user2.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user2.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateResp := httptest.NewRecorder()
postHandler.UpdatePost(updateResp, updateReq)
testutils.AssertHTTPStatus(t, updateResp, http.StatusForbidden)
deleteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteReq.Header.Set("Authorization", "Bearer "+user2.Token)
deleteReq = testutils.WithUserContext(deleteReq, middleware.UserIDKey, user2.User.ID)
deleteReq = testutils.WithURLParams(deleteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteResp := httptest.NewRecorder()
postHandler.DeletePost(deleteResp, deleteReq)
testutils.AssertHTTPStatus(t, deleteResp, http.StatusForbidden)
deleteResponse := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
testutils.AssertHTTPStatus(t, deleteResponse, http.StatusForbidden)
})
t.Run("Authorization_Vote_Access_Control", func(t *testing.T) {
emailSender.Reset()
user1 := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user1", "vote_auth1@example.com")
user2 := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user2", "vote_auth2@example.com")
firstUser := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user1", "vote_auth1@example.com")
secondUser := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user2", "vote_auth2@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user1.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
post := testutils.CreatePostWithRepo(t, postRepo, firstUser.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
voteData := map[string]string{"type": "up"}
voteBody, _ := json.Marshal(voteData)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBody))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user2.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user2.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteResp := httptest.NewRecorder()
voteHandler.CastVote(voteResp, voteReq)
if voteResp.Code != http.StatusOK {
t.Errorf("Users should be able to vote on any post, got %d", voteResp.Code)
voteResponse := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
if voteResponse.Code != http.StatusOK {
t.Errorf("Users should be able to vote on any post, got %d", voteResponse.Code)
}
})
@@ -664,12 +446,8 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to generate expired token: %v", err)
}
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+expiredToken)
meResp := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized)
meResponse := makeRequest(t, ctx.Router, "GET", "/api/auth/me", nil, map[string]string{"Authorization": "Bearer " + expiredToken})
testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
})
t.Run("Authorization_Token_Tampering", func(t *testing.T) {
@@ -678,12 +456,12 @@ func TestIntegration_Handlers(t *testing.T) {
tamperedToken := user.Token[:len(user.Token)-5] + "XXXXX"
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+tamperedToken)
meResp := httptest.NewRecorder()
meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meRequest.Header.Set("Authorization", "Bearer "+tamperedToken)
meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized)
ctx.Router.ServeHTTP(meResponse, meRequest)
testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
})
t.Run("Authorization_Session_Version_Mismatch", func(t *testing.T) {
@@ -711,12 +489,12 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to generate invalid token: %v", err)
}
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+invalidToken)
meResp := httptest.NewRecorder()
meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meRequest.Header.Set("Authorization", "Bearer "+invalidToken)
meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized)
ctx.Router.ServeHTTP(meResponse, meRequest)
testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
})
}
@@ -748,11 +526,9 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
}
voteService := services.NewVoteService(voteRepo, postRepo, db)
apiHandler := handlers.NewAPIHandlerWithMonitoring(testutils.AppTestConfig, postRepo, userRepo, voteService, db, monitor)
t.Run("Health endpoint includes database monitoring", func(t *testing.T) {
user := &database.User{
Username: "monitoring_user",
Email: "monitoring@example.com",
@@ -765,7 +541,6 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
recorder := httptest.NewRecorder()
apiHandler.GetHealth(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var response map[string]any
@@ -777,7 +552,7 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
t.Error("Expected success to be true")
}
data, ok := response["data"].(map[string]any)
data, ok := getDataFromResponse(response)
if !ok {
t.Fatal("Expected data to be a map")
}
@@ -813,7 +588,6 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
recorder := httptest.NewRecorder()
apiHandler.GetMetrics(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var response map[string]any
@@ -825,7 +599,7 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
t.Error("Expected success to be true")
}
data, ok := response["data"].(map[string]any)
data, ok := getDataFromResponse(response)
if !ok {
t.Fatal("Expected data to be a map")
}

View File

@@ -1,6 +1,7 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -282,14 +283,14 @@ func setupPageHandlerTestContext(t *testing.T) *testContext {
func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*http.Cookie) string {
t.Helper()
req := httptest.NewRequest("GET", path, nil)
request := httptest.NewRequest("GET", path, nil)
for _, cookie := range cookies {
req.AddCookie(cookie)
request.AddCookie(cookie)
}
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
cookieList := rec.Result().Cookies()
cookieList := recorder.Result().Cookies()
for _, cookie := range cookieList {
if cookie.Name == "csrf_token" {
return cookie.Value
@@ -299,32 +300,32 @@ func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*ht
return ""
}
func assertJSONResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) map[string]any {
func assertJSONResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) map[string]any {
t.Helper()
if rec.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String())
if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
return nil
}
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v. Body: %s", err, rec.Body.String())
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v. Body: %s", err, recorder.Body.String())
return nil
}
return response
}
func assertErrorResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) {
func assertErrorResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) {
t.Helper()
if rec.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String())
if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
return
}
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode error response: %v. Body: %s", err, rec.Body.String())
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode error response: %v. Body: %s", err, recorder.Body.String())
return
}
@@ -335,23 +336,23 @@ func assertErrorResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedS
}
}
func assertStatus(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) {
func assertStatus(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) {
t.Helper()
if rec.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String())
if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
}
}
func assertStatusRange(t *testing.T, rec *httptest.ResponseRecorder, minStatus, maxStatus int) {
func assertStatusRange(t *testing.T, recorder *httptest.ResponseRecorder, minStatus, maxStatus int) {
t.Helper()
if rec.Code < minStatus || rec.Code > maxStatus {
t.Errorf("Expected status between %d and %d, got %d. Body: %s", minStatus, maxStatus, rec.Code, rec.Body.String())
if recorder.Code < minStatus || recorder.Code > maxStatus {
t.Errorf("Expected status between %d and %d, got %d. Body: %s", minStatus, maxStatus, recorder.Code, recorder.Body.String())
}
}
func assertCookie(t *testing.T, rec *httptest.ResponseRecorder, name, expectedValue string) {
func assertCookie(t *testing.T, recorder *httptest.ResponseRecorder, name, expectedValue string) {
t.Helper()
cookies := rec.Result().Cookies()
cookies := recorder.Result().Cookies()
for _, cookie := range cookies {
if cookie.Name == name {
if expectedValue != "" && cookie.Value != expectedValue {
@@ -363,9 +364,9 @@ func assertCookie(t *testing.T, rec *httptest.ResponseRecorder, name, expectedVa
t.Errorf("Expected cookie %s not found", name)
}
func assertCookieCleared(t *testing.T, rec *httptest.ResponseRecorder, name string) {
func assertCookieCleared(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
t.Helper()
cookies := rec.Result().Cookies()
cookies := recorder.Result().Cookies()
for _, cookie := range cookies {
if cookie.Name == name {
if cookie.Value != "" {
@@ -376,21 +377,17 @@ func assertCookieCleared(t *testing.T, rec *httptest.ResponseRecorder, name stri
}
}
func assertHeader(t *testing.T, rec *httptest.ResponseRecorder, name, expectedValue string) {
func assertHeader(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
t.Helper()
actualValue := rec.Header().Get(name)
if expectedValue == "" {
if actualValue == "" {
t.Errorf("Expected header %s to be present", name)
}
} else if actualValue != expectedValue {
t.Errorf("Expected header %s=%s, got %s", name, expectedValue, actualValue)
actualValue := recorder.Header().Get(name)
if actualValue == "" {
t.Errorf("Expected header %s to be present", name)
}
}
func assertHeaderContains(t *testing.T, rec *httptest.ResponseRecorder, name, substring string) {
func assertHeaderContains(t *testing.T, recorder *httptest.ResponseRecorder, name, substring string) {
t.Helper()
actualValue := rec.Header().Get(name)
actualValue := recorder.Header().Get(name)
if !strings.Contains(actualValue, substring) {
t.Errorf("Expected header %s to contain %s, got %s", name, substring, actualValue)
}
@@ -450,3 +447,83 @@ func createUserWithCleanup(t *testing.T, ctx *testContext, username, email strin
})
return user
}
func makeRequest(t *testing.T, router http.Handler, method, path string, body []byte, headers map[string]string) *httptest.ResponseRecorder {
t.Helper()
var requestBody *bytes.Buffer
if body != nil {
requestBody = bytes.NewBuffer(body)
} else {
requestBody = bytes.NewBuffer(nil)
}
request := httptest.NewRequest(method, path, requestBody)
for key, value := range headers {
request.Header.Set(key, value)
}
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
return recorder
}
func makeAuthenticatedRequest(t *testing.T, router http.Handler, method, path string, body []byte, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
var requestBody *bytes.Buffer
if body != nil {
requestBody = bytes.NewBuffer(body)
} else {
requestBody = bytes.NewBuffer(nil)
}
request := httptest.NewRequest(method, path, requestBody)
request.Header.Set("Authorization", "Bearer "+user.Token)
if body != nil {
request.Header.Set("Content-Type", "application/json")
}
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
if urlParams != nil {
request = testutils.WithURLParams(request, urlParams)
}
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
return recorder
}
func makeGetRequest(t *testing.T, router http.Handler, path string) *httptest.ResponseRecorder {
t.Helper()
return makeRequest(t, router, "GET", path, nil, nil)
}
func makeAuthenticatedGetRequest(t *testing.T, router http.Handler, path string, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
return makeAuthenticatedRequest(t, router, "GET", path, nil, user, urlParams)
}
func makePostRequest(t *testing.T, router http.Handler, path string, body map[string]any, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeAuthenticatedRequest(t, router, "POST", path, bodyBytes, user, urlParams)
}
func makePutRequest(t *testing.T, router http.Handler, path string, body map[string]any, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeAuthenticatedRequest(t, router, "PUT", path, bodyBytes, user, urlParams)
}
func makeDeleteRequest(t *testing.T, router http.Handler, path string, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
return makeAuthenticatedRequest(t, router, "DELETE", path, nil, user, urlParams)
}
func makePostRequestWithJSON(t *testing.T, router http.Handler, path string, body map[string]any) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeRequest(t, router, "POST", path, bodyBytes, map[string]string{"Content-Type": "application/json"})
}
func getDataFromResponse(response map[string]any) (map[string]any, bool) {
if response == nil {
return nil, false
}
data, ok := response["data"].(map[string]any)
return data, ok
}

View File

@@ -22,26 +22,26 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, authService, ctx.Suite.UserRepo, "settings_email_user", "settings_email@example.com")
getReq := httptest.NewRequest("GET", "/settings", nil)
getReq.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", "/settings", nil)
getRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
getRecorder := httptest.NewRecorder()
router.ServeHTTP(getRecorder, getRequest)
csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("email", "newemail@example.com")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("email", "newemail@example.com")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/email", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/settings/email", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Settings_Username_Update_Form", func(t *testing.T) {
@@ -51,19 +51,19 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("username", "new_username")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("username", "new_username")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/username", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/settings/username", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Settings_Password_Update_Form", func(t *testing.T) {
@@ -74,20 +74,20 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("current_password", "SecurePass123!")
reqBody.Set("new_password", "NewSecurePass123!")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("current_password", "SecurePass123!")
requestBody.Set("new_password", "NewSecurePass123!")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/password", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/settings/password", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Logout_Page_Handler", func(t *testing.T) {
@@ -98,19 +98,19 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/logout", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/logout", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusSeeOther)
assertCookieCleared(t, rec, "auth_token")
assertStatus(t, recorder, http.StatusSeeOther)
assertCookieCleared(t, recorder, "auth_token")
})
t.Run("Resend_Verification_Page_Handler", func(t *testing.T) {
@@ -120,18 +120,18 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/resend-verification")
reqBody := url.Values{}
reqBody.Set("email", "resend_page@example.com")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("email", "resend_page@example.com")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/resend-verification", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/resend-verification", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Post_Vote_Page_Handler", func(t *testing.T) {
@@ -142,26 +142,26 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
post := testutils.CreatePostWithRepo(t, freshCtx.Suite.PostRepo, user.User.ID, "Vote Page Test", "https://example.com/vote-page")
getReq := httptest.NewRequest("GET", fmt.Sprintf("/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", fmt.Sprintf("/posts/%d", post.ID), nil)
getRecorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(getRecorder, getRequest)
csrfToken := getCSRFToken(t, freshCtx.Router, fmt.Sprintf("/posts/%d", post.ID))
reqBody := url.Values{}
reqBody.Set("action", "up")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("action", "up")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", fmt.Sprintf("/posts/%d/vote", post.ID), strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", fmt.Sprintf("/posts/%d/vote", post.ID), strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Login_Page_Handler_Workflow", func(t *testing.T) {
@@ -172,20 +172,20 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/login")
reqBody := url.Values{}
reqBody.Set("username", "login_page_user")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("username", "login_page_user")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/login", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/login", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusSeeOther)
assertCookie(t, rec, "auth_token", "")
assertStatus(t, recorder, http.StatusSeeOther)
assertCookie(t, recorder, "auth_token", "")
})
t.Run("Email_Confirmation_Page_Handler", func(t *testing.T) {
@@ -198,11 +198,11 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
token = "test-token"
}
req := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(token), nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(token), nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
}

View File

@@ -16,63 +16,62 @@ func TestIntegration_PageHandler(t *testing.T) {
router := ctx.Router
t.Run("Home_Page_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
if !strings.Contains(rec.Body.String(), "<html") {
if !strings.Contains(recorder.Body.String(), "<html") {
t.Error("Expected HTML content")
}
})
t.Run("Login_Form_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/login", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/login", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
body := rec.Body.String()
body := recorder.Body.String()
if !strings.Contains(body, "login") && !strings.Contains(body, "Login") {
t.Error("Expected login form content")
}
})
t.Run("Register_Form_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/register", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, recorder, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
body := rec.Body.String()
body := recorder.Body.String()
if !strings.Contains(body, "register") && !strings.Contains(body, "Register") {
t.Error("Expected register form content")
}
})
t.Run("PageHandler_With_CSRF_Token", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/register", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertCookie(t, rec, "csrf_token", "")
assertCookie(t, recorder, "csrf_token", "")
})
t.Run("PageHandler_Form_Submission", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", "/register", nil)
getRecorder := httptest.NewRecorder()
router.ServeHTTP(getRecorder, getRequest)
cookies := getRec.Result().Cookies()
cookies := getRecorder.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
@@ -85,33 +84,33 @@ func TestIntegration_PageHandler(t *testing.T) {
t.Fatal("Expected CSRF cookie")
}
reqBody := url.Values{}
reqBody.Set("username", "page_form_user")
reqBody.Set("email", "page_form@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfCookie.Value)
requestBody := url.Values{}
requestBody.Set("username", "page_form_user")
requestBody.Set("email", "page_form@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", csrfCookie.Value)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("PageHandler_Authenticated_Access", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "page_auth_user", "page_auth@example.com")
req := httptest.NewRequest("GET", "/settings", nil)
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/settings", nil)
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
})
t.Run("PageHandler_Post_Display", func(t *testing.T) {
@@ -120,34 +119,34 @@ func TestIntegration_PageHandler(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Page Test Post", "https://example.com/page-test")
req := httptest.NewRequest("GET", "/posts/"+fmt.Sprintf("%d", post.ID), nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/posts/"+fmt.Sprintf("%d", post.ID), nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
body := rec.Body.String()
body := recorder.Body.String()
if !strings.Contains(body, "Page Test Post") {
t.Error("Expected post title in page")
}
})
t.Run("PageHandler_Search_Page", func(t *testing.T) {
req := httptest.NewRequest("GET", "/search?q=test", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/search?q=test", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
})
t.Run("PageHandler_Error_Handling", func(t *testing.T) {
req := httptest.NewRequest("GET", "/nonexistent", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/nonexistent", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusNotFound)
assertStatus(t, recorder, http.StatusNotFound)
})
}

View File

@@ -1,8 +1,6 @@
package integration
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
@@ -32,17 +30,12 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to create user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"username_or_email": "reset_user",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/forgot-password", reqBody)
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success=true")
@@ -77,18 +70,13 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatal("Expected password reset token")
}
reqBody := map[string]string{
reqBody := map[string]any{
"token": resetToken,
"new_password": "NewPassword123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
loginResult, err := ctx.AuthService.Login("reset_complete_user", "NewPassword123!")
if err != nil {
@@ -120,14 +108,14 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
reqBody := url.Values{}
reqBody.Set("username_or_email", "page_reset_user")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/forgot-password", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/forgot-password", strings.NewReader(reqBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
pageRouter.ServeHTTP(rec, req)
pageRouter.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
resetToken := pageCtx.Suite.EmailSender.PasswordResetToken()
if resetToken == "" {
@@ -166,33 +154,23 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to update user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"token": resetToken,
"new_password": "NewPassword123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("PasswordReset_InvalidToken", func(t *testing.T) {
reqBody := map[string]string{
reqBody := map[string]any{
"token": "invalid-token",
"new_password": "NewPassword123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("PasswordReset_WeakPassword", func(t *testing.T) {
@@ -214,18 +192,13 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
resetToken := ctx.Suite.EmailSender.PasswordResetToken()
reqBody := map[string]string{
reqBody := map[string]any{
"token": resetToken,
"new_password": "123",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("PasswordReset_EmailIntegration", func(t *testing.T) {
@@ -243,17 +216,12 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to create user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"username_or_email": "email_reset@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, freshCtx.Router, "/api/auth/forgot-password", reqBody)
freshCtx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
resetToken := freshCtx.Suite.EmailSender.PasswordResetToken()
if resetToken == "" {

View File

@@ -51,24 +51,24 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
assertHeader(t, rec, "Retry-After", "")
assertHeader(t, recorder, "Retry-After")
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil {
if err := json.NewDecoder(recorder.Body).Decode(&response); err == nil {
if _, exists := response["retry_after"]; !exists {
t.Error("Expected retry_after in response")
}
@@ -81,17 +81,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
t.Run("Health_RateLimit_Enforced", func(t *testing.T) {
@@ -100,17 +100,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 3; i++ {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/health", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/health", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
t.Run("Metrics_RateLimit_Enforced", func(t *testing.T) {
@@ -119,17 +119,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ {
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
t.Run("RateLimit_Different_Endpoints_Independent", func(t *testing.T) {
@@ -139,17 +139,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
})
t.Run("RateLimit_With_Authentication", func(t *testing.T) {
@@ -166,20 +166,20 @@ func TestIntegration_RateLimiting(t *testing.T) {
user := createAuthenticatedUser(t, authService, suite.UserRepo, uniqueTestUsername(t, "ratelimit_auth"), uniqueTestEmail(t, "ratelimit_auth"))
for i := 0; i < 3; i++ {
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/auth/me", nil)
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/auth/me", nil)
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
}

View File

@@ -17,35 +17,29 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
router := ctx.Router
t.Run("SecurityHeaders_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/health")
router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
assertHeader(t, rec, "X-Content-Type-Options", "")
assertHeader(t, rec, "X-Frame-Options", "")
assertHeader(t, rec, "X-XSS-Protection", "")
assertHeader(t, request, "X-Content-Type-Options")
assertHeader(t, request, "X-Frame-Options")
assertHeader(t, request, "X-XSS-Protection")
})
t.Run("CORS_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("OPTIONS", "/api/posts", nil)
req.Header.Set("Origin", "http://localhost:3000")
rec := httptest.NewRecorder()
request := httptest.NewRequest("OPTIONS", "/api/posts", nil)
request.Header.Set("Origin", "http://localhost:3000")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertHeader(t, rec, "Access-Control-Allow-Origin", "")
assertHeader(t, recorder, "Access-Control-Allow-Origin")
})
t.Run("Logging_Middleware_Executes", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/health")
router.ServeHTTP(rec, req)
if rec.Code == 0 {
if request.Code == 0 {
t.Error("Expected logging middleware to execute")
}
})
@@ -53,27 +47,24 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
t.Run("RequestSizeLimit_Enforced", func(t *testing.T) {
user := createUserWithCleanup(t, ctx, "size_limit_user", "size_limit@example.com")
largeBody := strings.Repeat("a", 10*1024*1024)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(largeBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(largeBody))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusRequestEntityTooLarge && rec.Code != http.StatusBadRequest {
t.Errorf("Expected status 413 or 400 for oversized request, got %d. Body: %s", rec.Code, rec.Body.String())
if recorder.Code != http.StatusRequestEntityTooLarge && recorder.Code != http.StatusBadRequest {
t.Errorf("Expected status 413 or 400 for oversized request, got %d. Body: %s", recorder.Code, recorder.Body.String())
}
})
t.Run("DBMonitoring_Active", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := makeGetRequest(t, router, "/health")
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil {
if err := json.NewDecoder(request.Body).Decode(&response); err == nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database_stats"]; !exists {
t.Error("Expected database_stats in health response")
@@ -83,12 +74,9 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
})
t.Run("Metrics_Middleware_Executes", func(t *testing.T) {
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/metrics")
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database"]; !exists {
@@ -99,34 +87,25 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
})
t.Run("StaticFiles_Served", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/robots.txt")
router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
if !strings.Contains(rec.Body.String(), "User-agent") {
if !strings.Contains(request.Body.String(), "User-agent") {
t.Error("Expected robots.txt content")
}
})
t.Run("API_Routes_Accessible", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/api/posts")
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("Health_Endpoint_Accessible", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/health")
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success=true in health response")
@@ -135,40 +114,33 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
})
t.Run("Middleware_Order_Correct", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/api/posts")
router.ServeHTTP(rec, req)
assertHeader(t, request, "X-Content-Type-Options")
assertHeader(t, rec, "X-Content-Type-Options", "")
if rec.Code == 0 {
if request.Code == 0 {
t.Error("Response should have status code")
}
})
t.Run("Compression_Middleware_Active", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip")
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
request.Header.Set("Accept-Encoding", "gzip")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Header().Get("Content-Encoding") == "" {
if recorder.Header().Get("Content-Encoding") == "" {
t.Log("Compression may not be applied to small responses")
}
})
t.Run("Cache_Middleware_Active", func(t *testing.T) {
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := makeGetRequest(t, router, "/api/posts")
req2 := httptest.NewRequest("GET", "/api/posts", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := makeGetRequest(t, router, "/api/posts")
if rec1.Code != rec2.Code {
if firstRequest.Code != secondRequest.Code {
t.Error("Cached responses should have same status")
}
})
@@ -177,35 +149,23 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "auth_middleware_user", "auth_middleware@example.com")
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, router, "/api/auth/me", user, nil)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("RateLimit_Middleware_Integration", func(t *testing.T) {
rateLimitCtx := setupTestContext(t)
rateLimitRouter := rateLimitCtx.Router
for i := 0; i < 3; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
rateLimitRouter.ServeHTTP(rec, req)
for range 3 {
request := makePostRequestWithJSON(t, rateLimitRouter, "/api/auth/login", map[string]any{"username": "test", "password": "test"})
_ = request
}
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, rateLimitRouter, "/api/auth/login", map[string]any{"username": "test", "password": "test"})
rateLimitRouter.ServeHTTP(rec, req)
if rec.Code == http.StatusTooManyRequests {
if request.Code == http.StatusTooManyRequests {
t.Log("Rate limiting is working")
}
})

View File

@@ -21,60 +21,60 @@ func TestIntegration_SessionManagement(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_pass_user", "session_pass@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user.User.ID)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
firstRequest.Header.Set("Authorization", "Bearer "+user.Token)
firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, user.User.ID)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
assertStatus(t, rec1, http.StatusOK)
assertStatus(t, firstRecorder, http.StatusOK)
reqBody := map[string]string{
requestBody := map[string]string{
"current_password": "SecurePass123!",
"new_password": "NewSecurePass123!",
}
body, _ := json.Marshal(reqBody)
req2 := httptest.NewRequest("PUT", "/api/auth/password", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
body, _ := json.Marshal(requestBody)
secondRequest := httptest.NewRequest("PUT", "/api/auth/password", bytes.NewBuffer(body))
secondRequest.Header.Set("Content-Type", "application/json")
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec2, http.StatusOK)
assertStatus(t, secondRecorder, http.StatusOK)
req3 := httptest.NewRequest("GET", "/api/auth/me", nil)
req3.Header.Set("Authorization", "Bearer "+user.Token)
req3 = testutils.WithUserContext(req3, middleware.UserIDKey, user.User.ID)
rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3)
thirdRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
thirdRequest.Header.Set("Authorization", "Bearer "+user.Token)
thirdRequest = testutils.WithUserContext(thirdRequest, middleware.UserIDKey, user.User.ID)
thirdRecorder := httptest.NewRecorder()
router.ServeHTTP(thirdRecorder, thirdRequest)
assertErrorResponse(t, rec3, http.StatusUnauthorized)
assertErrorResponse(t, thirdRecorder, http.StatusUnauthorized)
})
t.Run("Session_Invalidation_On_Account_Lock", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_lock_user", "session_lock@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user.User.ID)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
firstRequest.Header.Set("Authorization", "Bearer "+user.Token)
firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, user.User.ID)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
assertStatus(t, rec1, http.StatusOK)
assertStatus(t, firstRecorder, http.StatusOK)
if err := ctx.Suite.UserRepo.Lock(user.User.ID); err != nil {
t.Fatalf("Failed to lock user: %v", err)
}
req2 := httptest.NewRequest("GET", "/api/auth/me", nil)
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertErrorResponse(t, rec2, http.StatusUnauthorized)
assertErrorResponse(t, secondRecorder, http.StatusUnauthorized)
})
t.Run("Refresh_Token_Revocation", func(t *testing.T) {
@@ -90,48 +90,48 @@ func TestIntegration_SessionManagement(t *testing.T) {
t.Fatal("Expected refresh token")
}
reqBody := map[string]string{
requestBody := map[string]string{
"refresh_token": loginResult.RefreshToken,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil {
t.Fatalf("Failed to revoke token: %v", err)
}
req2 := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json")
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
secondRequest.Header.Set("Content-Type", "application/json")
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertErrorResponse(t, rec2, http.StatusUnauthorized)
assertErrorResponse(t, secondRecorder, http.StatusUnauthorized)
})
t.Run("Multiple_Sessions_Independent", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user1", "multi_session1@example.com")
user2 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user2", "multi_session2@example.com")
firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user1", "multi_session1@example.com")
secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user2", "multi_session2@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user1.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user1.User.ID)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
firstRequest.Header.Set("Authorization", "Bearer "+firstUser.Token)
firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, firstUser.User.ID)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
req2 := httptest.NewRequest("GET", "/api/auth/me", nil)
req2.Header.Set("Authorization", "Bearer "+user2.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user2.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
secondRequest.Header.Set("Authorization", "Bearer "+secondUser.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, secondUser.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec1, http.StatusOK)
assertStatus(t, rec2, http.StatusOK)
assertStatus(t, firstRecorder, http.StatusOK)
assertStatus(t, secondRecorder, http.StatusOK)
})
}
@@ -144,17 +144,17 @@ func TestIntegration_AccountDeletion(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_flow_user", "del_flow@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Test Post", "https://example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
requestBody := map[string]string{}
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, recorder, http.StatusOK)
if response == nil {
return
}
@@ -171,13 +171,13 @@ func TestIntegration_AccountDeletion(t *testing.T) {
"token": deletionToken,
}
confirmBodyBytes, _ := json.Marshal(confirmBody)
confirmReq := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmReq.Header.Set("Content-Type", "application/json")
confirmRec := httptest.NewRecorder()
confirmRequest := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmRequest.Header.Set("Content-Type", "application/json")
confirmRecorder := httptest.NewRecorder()
router.ServeHTTP(confirmRec, confirmReq)
router.ServeHTTP(confirmRecorder, confirmRequest)
confirmResponse := assertJSONResponse(t, confirmRec, http.StatusOK)
confirmResponse := assertJSONResponse(t, confirmRecorder, http.StatusOK)
if confirmResponse == nil {
return
}
@@ -209,17 +209,17 @@ func TestIntegration_AccountDeletion(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_posts_user", "del_posts@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Deletion Post", "https://example.com/deletion")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
requestBody := map[string]string{}
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, recorder, http.StatusOK)
if response == nil {
return
}
@@ -237,13 +237,13 @@ func TestIntegration_AccountDeletion(t *testing.T) {
"delete_posts": true,
}
confirmBodyBytes, _ := json.Marshal(confirmBody)
confirmReq := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmReq.Header.Set("Content-Type", "application/json")
confirmRec := httptest.NewRecorder()
confirmRequest := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmRequest.Header.Set("Content-Type", "application/json")
confirmRecorder := httptest.NewRecorder()
router.ServeHTTP(confirmRec, confirmReq)
router.ServeHTTP(confirmRecorder, confirmRequest)
confirmResponse := assertJSONResponse(t, confirmRec, http.StatusOK)
confirmResponse := assertJSONResponse(t, confirmRecorder, http.StatusOK)
if confirmResponse == nil {
return
}
@@ -275,12 +275,12 @@ func TestIntegration_MetricsCollection(t *testing.T) {
router := ctx.Router
t.Run("Metrics_Endpoint_Returns_Data", func(t *testing.T) {
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, recorder, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database"]; !exists {
@@ -294,13 +294,13 @@ func TestIntegration_MetricsCollection(t *testing.T) {
ctx.Suite.EmailSender.Reset()
createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "metrics_user", "metrics@example.com")
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil {
if err := json.NewDecoder(recorder.Body).Decode(&response); err == nil {
if data, ok := response["data"].(map[string]any); ok {
if dbData, exists := data["database"].(map[string]any); exists {
if _, hasQueries := dbData["total_queries"]; !hasQueries {
@@ -323,7 +323,7 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, 10)
for i := 0; i < 10; i++ {
for idx := range 10 {
wg.Add(1)
go func(index int) {
defer wg.Done()
@@ -334,18 +334,18 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
"content": "Concurrent test content",
}
body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusCreated {
errors <- fmt.Errorf("Post %d failed with status %d", index, rec.Code)
if recorder.Code != http.StatusCreated {
errors <- fmt.Errorf("Post %d failed with status %d", index, recorder.Code)
}
}(i)
}(idx)
}
wg.Wait()
@@ -370,28 +370,26 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, 5)
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range 5 {
wg.Go(func() {
voteBody := map[string]string{
"type": "up",
}
body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK {
errors <- fmt.Errorf("Vote failed with status %d", rec.Code)
if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("Vote failed with status %d", recorder.Code)
}
}()
})
}
wg.Wait()
@@ -411,20 +409,18 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, 20)
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range 20 {
wg.Go(func() {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK {
errors <- fmt.Errorf("Read failed with status %d", rec.Code)
if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("Read failed with status %d", recorder.Code)
}
}()
})
}
wg.Wait()

View File

@@ -38,6 +38,10 @@ func CompressionMiddlewareWithConfig(config *CompressionConfig) func(http.Handle
next.ServeHTTP(bufferedWriter, r)
if bufferedWriter.isRedirect {
return
}
if buf.Len() < config.MinSize {
bufferedWriter.flush()
w.Write(buf.Bytes())
@@ -73,9 +77,13 @@ type bufferedResponseWriter struct {
buffer *bytes.Buffer
statusCode int
headerWritten bool
isRedirect bool
}
func (brw *bufferedResponseWriter) Write(b []byte) (int, error) {
if brw.isRedirect {
return brw.ResponseWriter.Write(b)
}
if !brw.headerWritten {
brw.statusCode = http.StatusOK
}
@@ -87,6 +95,11 @@ func (brw *bufferedResponseWriter) WriteHeader(code int) {
return
}
brw.statusCode = code
if isRedirect(code) {
brw.isRedirect = true
brw.ResponseWriter.WriteHeader(code)
brw.headerWritten = true
}
}
func (brw *bufferedResponseWriter) Header() http.Header {
@@ -100,6 +113,10 @@ func (brw *bufferedResponseWriter) flush() {
}
}
func isRedirect(statusCode int) bool {
return statusCode >= 300 && statusCode < 400
}
func shouldCompress(r *http.Request, config *CompressionConfig) bool {
return r.Header.Get("Content-Encoding") == ""
}

View File

@@ -27,7 +27,14 @@ func ValidationMiddleware() func(http.Handler) http.Handler {
dto := reflect.New(dtoType).Interface()
if err := json.NewDecoder(r.Body).Decode(dto); err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
response := map[string]any{
"success": false,
"error": "Invalid JSON",
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(response)
return
}
@@ -77,3 +84,7 @@ func GetDTOTypeFromContext(ctx context.Context) reflect.Type {
func GetValidatedDTOFromContext(ctx context.Context) any {
return ctx.Value(validatedDTOKey)
}
func SetValidatedDTOInContext(ctx context.Context, dto any) context.Context {
return context.WithValue(ctx, validatedDTOKey, dto)
}

View File

@@ -28,6 +28,7 @@ type UserRepository interface {
Unlock(id uint) error
GetPosts(userID uint, limit, offset int) ([]database.Post, error)
GetDeletedUsers() ([]database.User, error)
GetByUsernamePrefix(prefix string) (*database.User, error)
HardDeleteAll() (int64, error)
Count() (int64, error)
WithTx(tx *gorm.DB) UserRepository
@@ -240,6 +241,17 @@ func (r *userRepository) GetDeletedUsers() ([]database.User, error) {
return users, err
}
func (r *userRepository) GetByUsernamePrefix(prefix string) (*database.User, error) {
var user database.User
err := r.db.
Where("username LIKE ? AND email LIKE ?", prefix+"%", prefix+"%@goyco.local").
First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
func (r *userRepository) HardDeleteAll() (int64, error) {
var totalDeleted int64
err := r.db.Transaction(func(tx *gorm.DB) error {

View File

@@ -20,6 +20,7 @@ type VoteRepository interface {
Count() (int64, error)
CountByPostID(postID uint) (int64, error)
CountByUserID(userID uint) (int64, error)
GetVoteCountsByPostID(postID uint) (upVotes int, downVotes int, err error)
WithTx(tx *gorm.DB) VoteRepository
}
@@ -144,3 +145,20 @@ func (r *voteRepository) Count() (int64, error) {
err := r.db.Model(&database.Vote{}).Count(&count).Error
return count, err
}
func (r *voteRepository) GetVoteCountsByPostID(postID uint) (int, int, error) {
var result struct {
UpVotes int64
DownVotes int64
}
err := r.db.Model(&database.Vote{}).
Select("COUNT(CASE WHEN type = ? THEN 1 END) as up_votes, COUNT(CASE WHEN type = ? THEN 1 END) as down_votes", database.VoteUp, database.VoteDown).
Where("post_id = ?", postID).
Scan(&result).Error
if err != nil {
return 0, 0, err
}
return int(result.UpVotes), int(result.DownVotes), nil
}

View File

@@ -1,8 +1,10 @@
package server
import (
"mime"
"net/http"
"path/filepath"
"strings"
"time"
"goyco/internal/config"
@@ -71,8 +73,9 @@ func NewRouter(cfg RouterConfig) http.Handler {
AuthRateLimit: func(r chi.Router) chi.Router {
return r.With(middleware.AuthRateLimitMiddlewareWithLimit(cfg.RateLimitConfig.AuthLimit))
},
CSRFMiddleware: middleware.CSRFMiddleware(),
AuthMiddleware: middleware.NewAuth(cfg.AuthService),
CSRFMiddleware: middleware.CSRFMiddleware(),
AuthMiddleware: middleware.NewAuth(cfg.AuthService),
ValidationMiddleware: middleware.ValidationMiddleware(),
}
if cfg.PageHandler != nil {
@@ -123,7 +126,33 @@ func NewRouter(cfg RouterConfig) http.Handler {
staticDir = "./internal/static/"
}
router.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir))))
staticFileServer := http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir)))
router.Handle("/static/*", staticFileHandler(staticFileServer))
return router
}
func staticFileHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
ext := filepath.Ext(path)
if ext == ".css" {
w.Header().Set("Content-Type", "text/css; charset=utf-8")
} else if ext == ".js" {
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
} else if ext == ".json" {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
} else if ext == ".ico" {
w.Header().Set("Content-Type", "image/x-icon")
} else if strings.HasPrefix(mime.TypeByExtension(ext), "image/") {
w.Header().Set("Content-Type", mime.TypeByExtension(ext))
} else if strings.HasPrefix(mime.TypeByExtension(ext), "font/") {
w.Header().Set("Content-Type", mime.TypeByExtension(ext))
} else if mimeType := mime.TypeByExtension(ext); mimeType != "" {
w.Header().Set("Content-Type", mimeType)
}
next.ServeHTTP(w, r)
})
}

View File

@@ -3,6 +3,7 @@ package server
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"goyco/internal/config"
@@ -105,9 +106,9 @@ func defaultRateLimitConfig() config.RateLimitConfig {
return testutils.AppTestConfig.RateLimit
}
func TestAPIRootRouting(t *testing.T) {
func createDefaultRouterConfig() RouterConfig {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
return RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
@@ -115,7 +116,15 @@ func TestAPIRootRouting(t *testing.T) {
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
}
}
func createTestRouter(cfg RouterConfig) http.Handler {
return NewRouter(cfg)
}
func TestAPIRootRouting(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
@@ -141,23 +150,23 @@ func TestAPIRootRouting(t *testing.T) {
}
func TestProtectedRoutesRequireAuth(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
protectedRoutes := []struct {
method string
path string
}{
{http.MethodGet, "/api/auth/me"},
{http.MethodPost, "/api/auth/logout"},
{http.MethodPost, "/api/auth/revoke"},
{http.MethodPost, "/api/auth/revoke-all"},
{http.MethodPut, "/api/auth/email"},
{http.MethodPut, "/api/auth/username"},
{http.MethodPut, "/api/auth/password"},
{http.MethodDelete, "/api/auth/account"},
{http.MethodPost, "/api/posts"},
{http.MethodPut, "/api/posts/1"},
{http.MethodDelete, "/api/posts/1"},
{http.MethodPost, "/api/posts/1/vote"},
{http.MethodDelete, "/api/posts/1/vote"},
{http.MethodGet, "/api/posts/1/vote"},
@@ -183,17 +192,9 @@ func TestProtectedRoutesRequireAuth(t *testing.T) {
}
func TestRouterWithDebugMode(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
Debug: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.Debug = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -206,16 +207,9 @@ func TestRouterWithDebugMode(t *testing.T) {
}
func TestRouterWithCacheDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
DisableCache: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
})
cfg := createDefaultRouterConfig()
cfg.DisableCache = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -228,17 +222,9 @@ func TestRouterWithCacheDisabled(t *testing.T) {
}
func TestRouterWithCompressionDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
DisableCompression: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.DisableCompression = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -251,19 +237,9 @@ func TestRouterWithCompressionDisabled(t *testing.T) {
}
func TestRouterWithCustomDBMonitor(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
customDBMonitor := middleware.NewInMemoryDBMonitor()
router := NewRouter(RouterConfig{
DBMonitor: customDBMonitor,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.DBMonitor = middleware.NewInMemoryDBMonitor()
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -296,18 +272,9 @@ func TestRouterWithPageHandler(t *testing.T) {
}
func TestRouterWithStaticDir(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
StaticDir: "/custom/static/path",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.StaticDir = "/custom/static/path"
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -320,18 +287,9 @@ func TestRouterWithStaticDir(t *testing.T) {
}
func TestRouterWithEmptyStaticDir(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
StaticDir: "",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.StaticDir = ""
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -344,20 +302,11 @@ func TestRouterWithEmptyStaticDir(t *testing.T) {
}
func TestRouterWithAllFeaturesDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
Debug: true,
DisableCache: true,
DisableCompression: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.Debug = true
cfg.DisableCache = true
cfg.DisableCompression = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -370,15 +319,9 @@ func TestRouterWithAllFeaturesDisabled(t *testing.T) {
}
func TestRouterWithoutAPIHandler(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, _, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.APIHandler = nil
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -391,17 +334,7 @@ func TestRouterWithoutAPIHandler(t *testing.T) {
}
func TestRouterWithoutPageHandler(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
request := httptest.NewRequest(http.MethodGet, "/", nil)
recorder := httptest.NewRecorder()
@@ -414,17 +347,7 @@ func TestRouterWithoutPageHandler(t *testing.T) {
}
func TestSwaggerRoute(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
recorder := httptest.NewRecorder()
@@ -437,18 +360,9 @@ func TestSwaggerRoute(t *testing.T) {
}
func TestStaticFileRoute(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
StaticDir: "../../internal/static/",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.StaticDir = "../../internal/static/"
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/static/css/main.css", nil)
recorder := httptest.NewRecorder()
@@ -461,17 +375,7 @@ func TestStaticFileRoute(t *testing.T) {
}
func TestRouterConfiguration(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
if router == nil {
t.Error("Router should not be nil")
@@ -487,29 +391,484 @@ func TestRouterConfiguration(t *testing.T) {
}
}
func TestRouterMiddlewareIntegration(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
func TestAllRoutesExist(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
if router == nil {
t.Error("Router should not be nil")
publicRoutes := []struct {
method string
path string
description string
}{
{http.MethodGet, "/api", "API info"},
{http.MethodGet, "/health", "Health check"},
{http.MethodGet, "/metrics", "Metrics"},
{http.MethodGet, "/robots.txt", "Robots.txt"},
{http.MethodGet, "/api/posts", "Get posts"},
{http.MethodGet, "/api/posts/search", "Search posts"},
{http.MethodGet, "/api/posts/title", "Fetch title from URL"},
{http.MethodGet, "/api/posts/1", "Get post by ID"},
{http.MethodPost, "/api/auth/register", "Register"},
{http.MethodPost, "/api/auth/login", "Login"},
{http.MethodPost, "/api/auth/refresh", "Refresh token"},
{http.MethodGet, "/api/auth/confirm", "Confirm email"},
{http.MethodPost, "/api/auth/resend-verification", "Resend verification"},
{http.MethodPost, "/api/auth/forgot-password", "Forgot password"},
{http.MethodPost, "/api/auth/reset-password", "Reset password"},
{http.MethodPost, "/api/auth/account/confirm", "Confirm account deletion"},
}
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
protectedRoutes := []struct {
method string
path string
description string
}{
{http.MethodGet, "/api/auth/me", "Get current user"},
{http.MethodPost, "/api/auth/logout", "Logout"},
{http.MethodPost, "/api/auth/revoke", "Revoke token"},
{http.MethodPost, "/api/auth/revoke-all", "Revoke all tokens"},
{http.MethodPut, "/api/auth/email", "Update email"},
{http.MethodPut, "/api/auth/username", "Update username"},
{http.MethodPut, "/api/auth/password", "Update password"},
{http.MethodDelete, "/api/auth/account", "Delete account"},
{http.MethodPost, "/api/posts", "Create post"},
{http.MethodPut, "/api/posts/1", "Update post"},
{http.MethodDelete, "/api/posts/1", "Delete post"},
{http.MethodPost, "/api/posts/1/vote", "Cast vote"},
{http.MethodDelete, "/api/posts/1/vote", "Remove vote"},
{http.MethodGet, "/api/posts/1/vote", "Get user vote"},
{http.MethodGet, "/api/posts/1/votes", "Get post votes"},
{http.MethodGet, "/api/users", "Get users"},
{http.MethodPost, "/api/users", "Create user"},
{http.MethodGet, "/api/users/1", "Get user by ID"},
{http.MethodGet, "/api/users/1/posts", "Get user posts"},
}
router.ServeHTTP(recorder, request)
for _, route := range publicRoutes {
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
invalidMethod := http.MethodPatch
switch route.method {
case http.MethodGet:
invalidMethod = http.MethodDelete
case http.MethodPost:
invalidMethod = http.MethodGet
}
request := httptest.NewRequest(invalidMethod, route.path, nil)
recorder := httptest.NewRecorder()
if recorder.Code == 0 {
t.Error("Router should return a status code")
router.ServeHTTP(recorder, request)
routeExists := recorder.Code == http.StatusMethodNotAllowed
if !routeExists {
request = httptest.NewRequest(route.method, route.path, nil)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == http.StatusNotFound && route.path != "/api/posts/1" && route.path != "/robots.txt" {
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
}
}
})
}
for _, route := range protectedRoutes {
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
request := httptest.NewRequest(route.method, route.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
}
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Protected route %s %s should return 401 without auth, got %d", route.method, route.path, recorder.Code)
}
})
}
}
func TestRouteParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
pathPattern string
testIDs []string
isProtected bool
}{
{
name: "Get post by ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: false,
},
{
name: "Update post by ID",
method: http.MethodPut,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Delete post by ID",
method: http.MethodDelete,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Get user by ID",
method: http.MethodGet,
pathPattern: "/api/users/{id}",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Get user posts by user ID",
method: http.MethodGet,
pathPattern: "/api/users/{id}/posts",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Cast vote for post ID",
method: http.MethodPost,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Remove vote for post ID",
method: http.MethodDelete,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Get user vote for post ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Get post votes by post ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}/votes",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.testIDs {
path := replaceID(tc.pathPattern, id)
t.Run("ID_"+id, func(t *testing.T) {
request := httptest.NewRequest(http.MethodPatch, path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
routeExists := recorder.Code == http.StatusMethodNotAllowed
request = httptest.NewRequest(tc.method, path, nil)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if !routeExists {
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist with ID %s, got 404", tc.method, path, id)
return
}
}
if tc.isProtected {
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Protected route %s %s should return 401 without auth, got %d", tc.method, path, recorder.Code)
}
}
})
}
})
}
}
func replaceID(pattern, id string) string {
return strings.Replace(pattern, "{id}", id, 1)
}
func TestInvalidRouteParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
expectedMin int
expectedMax int
isProtected bool
allow401 bool
}{
{
name: "Non-numeric post ID",
method: http.MethodGet,
path: "/api/posts/abc",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Negative post ID",
method: http.MethodGet,
path: "/api/posts/-1",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Zero post ID",
method: http.MethodGet,
path: "/api/posts/0",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusNotFound,
isProtected: false,
},
{
name: "Post ID with special characters",
method: http.MethodGet,
path: "/api/posts/123@456",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Post ID with encoded spaces",
method: http.MethodGet,
path: "/api/posts/12%2034",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Non-numeric user ID",
method: http.MethodGet,
path: "/api/users/xyz",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Negative user ID",
method: http.MethodGet,
path: "/api/users/-5",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Non-numeric post ID in vote route",
method: http.MethodGet,
path: "/api/posts/invalid/vote",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Very large post ID",
method: http.MethodGet,
path: "/api/posts/999999999999",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusNotFound,
isProtected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest(tc.method, tc.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if tc.isProtected && tc.allow401 {
if recorder.Code != http.StatusUnauthorized && (recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax) {
t.Errorf("Protected route %s %s with invalid parameter should return 401 or status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
}
} else {
if recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax {
t.Errorf("Route %s %s should return status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
}
if recorder.Code != http.StatusNotFound && recorder.Code < 400 {
t.Errorf("Route %s %s with invalid parameter should return error status (4xx), got %d", tc.method, tc.path, recorder.Code)
}
}
})
}
}
func TestQueryParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
queryParams string
expectRoute bool
}{
{
name: "Get posts with limit and offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=10&offset=5",
expectRoute: true,
},
{
name: "Get posts with only limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=20",
expectRoute: true,
},
{
name: "Get posts with only offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "offset=10",
expectRoute: true,
},
{
name: "Search posts with query parameter",
method: http.MethodGet,
path: "/api/posts/search",
queryParams: "q=test",
expectRoute: true,
},
{
name: "Search posts with query, limit, and offset",
method: http.MethodGet,
path: "/api/posts/search",
queryParams: "q=test&limit=15&offset=3",
expectRoute: true,
},
{
name: "Fetch title with URL parameter",
method: http.MethodGet,
path: "/api/posts/title",
queryParams: "url=https://example.com",
expectRoute: true,
},
{
name: "Confirm email with token parameter",
method: http.MethodGet,
path: "/api/auth/confirm",
queryParams: "token=abc123",
expectRoute: true,
},
{
name: "Get posts with invalid limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=abc",
expectRoute: true,
},
{
name: "Get posts with negative limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=-5",
expectRoute: true,
},
{
name: "Get posts with negative offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "offset=-10",
expectRoute: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fullPath := tc.path
if tc.queryParams != "" {
fullPath += "?" + tc.queryParams
}
request := httptest.NewRequest(tc.method, fullPath, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if tc.expectRoute {
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist with query parameters, got 404", tc.method, fullPath)
}
}
})
}
}
func TestRouteConflicts(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
description string
}{
{
name: "posts/search should not match posts/{id}",
method: http.MethodGet,
path: "/api/posts/search",
description: "search route should be matched, not treated as ID",
},
{
name: "posts/title should not match posts/{id}",
method: http.MethodGet,
path: "/api/posts/title",
description: "title route should be matched, not treated as ID",
},
{
name: "posts/{id} should work with numeric ID",
method: http.MethodGet,
path: "/api/posts/123",
description: "numeric ID should match {id} route",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest(tc.method, tc.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
switch tc.path {
case "/api/posts/search":
if recorder.Code == http.StatusNotFound {
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
}
case "/api/posts/title":
if recorder.Code == http.StatusNotFound {
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
}
case "/api/posts/123":
if recorder.Code == http.StatusNotFound {
return
}
if recorder.Code < 400 {
t.Errorf("%s: Route %s %s should return 4xx or 5xx, got %d", tc.description, tc.method, tc.path, recorder.Code)
}
}
})
}
}

View File

@@ -1,12 +1,93 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/mail"
"strings"
"goyco/internal/config"
"goyco/internal/database"
)
const (
defaultTokenExpirationHours = 24
verificationTokenBytes = 32
deletionTokenExpirationHours = 24
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid or expired token")
ErrUsernameTaken = errors.New("username already exists")
ErrEmailTaken = errors.New("email already exists")
ErrInvalidEmail = errors.New("invalid email address")
ErrPasswordTooShort = errors.New("password too short")
ErrEmailNotVerified = errors.New("email not verified")
ErrAccountLocked = errors.New("account is locked")
ErrInvalidVerificationToken = errors.New("invalid verification token")
ErrEmailSenderUnavailable = errors.New("email sender not configured")
ErrDeletionEmailFailed = errors.New("account deletion email failed")
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
ErrUserNotFound = errors.New("user not found")
ErrDeletionRequestNotFound = errors.New("deletion request not found")
)
type AuthResult struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
User *database.User `json:"user"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
VerificationSent bool `json:"verification_sent"`
}
func normalizeEmail(email string) (string, error) {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return "", fmt.Errorf("email is required")
}
parsed, err := mail.ParseAddress(trimmed)
if err != nil {
return "", ErrInvalidEmail
}
return strings.ToLower(parsed.Address), nil
}
func generateVerificationToken() (string, string, error) {
buf := make([]byte, verificationTokenBytes)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate verification token: %w", err)
}
token := hex.EncodeToString(buf)
hashed := HashVerificationToken(token)
return token, hashed, nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func sanitizeUser(user *database.User) *database.User {
if user == nil {
return nil
}
copy := *user
copy.Password = ""
copy.EmailVerificationToken = ""
return &copy
}
type AuthFacade struct {
registrationService *RegistrationService
passwordResetService *PasswordResetService

View File

@@ -1,35 +0,0 @@
package services
import (
"errors"
"goyco/internal/database"
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid or expired token")
ErrUsernameTaken = errors.New("username already exists")
ErrEmailTaken = errors.New("email already exists")
ErrInvalidEmail = errors.New("invalid email address")
ErrPasswordTooShort = errors.New("password too short")
ErrEmailNotVerified = errors.New("email not verified")
ErrAccountLocked = errors.New("account is locked")
ErrInvalidVerificationToken = errors.New("invalid verification token")
ErrEmailSenderUnavailable = errors.New("email sender not configured")
ErrDeletionEmailFailed = errors.New("account deletion email failed")
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
ErrUserNotFound = errors.New("user not found")
ErrDeletionRequestNotFound = errors.New("deletion request not found")
)
type AuthResult struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
User *database.User `json:"user"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
VerificationSent bool `json:"verification_sent"`
}

View File

@@ -1,59 +0,0 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/mail"
"strings"
"goyco/internal/database"
)
const (
defaultTokenExpirationHours = 24
verificationTokenBytes = 32
deletionTokenExpirationHours = 24
)
func normalizeEmail(email string) (string, error) {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return "", fmt.Errorf("email is required")
}
parsed, err := mail.ParseAddress(trimmed)
if err != nil {
return "", ErrInvalidEmail
}
return strings.ToLower(parsed.Address), nil
}
func generateVerificationToken() (string, string, error) {
buf := make([]byte, verificationTokenBytes)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate verification token: %w", err)
}
token := hex.EncodeToString(buf)
hashed := HashVerificationToken(token)
return token, hashed, nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func sanitizeUser(user *database.User) *database.User {
if user == nil {
return nil
}
copy := *user
copy.Password = ""
copy.EmailVerificationToken = ""
return &copy
}

View File

@@ -7,9 +7,10 @@ import (
"sync"
"testing"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/repositories"
"gorm.io/gorm"
)
type mockVoteRepo struct {
@@ -240,6 +241,25 @@ func (m *mockVoteRepo) CountByUserID(userID uint) (int64, error) {
return count, nil
}
func (m *mockVoteRepo) GetVoteCountsByPostID(postID uint) (int, int, error) {
m.mu.RLock()
defer m.mu.RUnlock()
upVotes := 0
downVotes := 0
for _, vote := range m.votes {
if vote.PostID == postID {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
}
return upVotes, downVotes, nil
}
func (m *mockVoteRepo) WithTx(tx *gorm.DB) repositories.VoteRepository {
return m
}

View File

@@ -2,43 +2,70 @@ package templates
import (
"html/template"
"io/fs"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTemplateParsing(t *testing.T) {
templateDir := "./"
func templateFuncMap() template.FuncMap {
return template.FuncMap{
"formatTime": func(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format("02 Jan 2006 15:04")
},
"truncate": func(s string, length int) string {
if len(s) <= length {
return s
}
if length <= 3 {
return s[:length]
}
return s[:length-3] + "..."
},
"substr": func(s string, start, length int) string {
if start >= len(s) {
return ""
}
end := min(start+length, len(s))
return s[start:end]
},
"upper": strings.ToUpper,
}
}
var templateFiles []string
err := filepath.WalkDir(templateDir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if !d.IsDir() && filepath.Ext(path) == ".gohtml" {
templateFiles = append(templateFiles, path)
}
return nil
})
func TestTemplateParsing(t *testing.T) {
layoutPath := filepath.Join(".", "base.gohtml")
require.FileExists(t, layoutPath, "base layout is required for all templates")
partials, err := filepath.Glob(filepath.Join(".", "partials", "*.gohtml"))
require.NoError(t, err)
tmpl := template.New("test")
pages, err := filepath.Glob(filepath.Join(".", "*.gohtml"))
require.NoError(t, err)
require.NotEmpty(t, pages, "no page templates found")
tmpl = tmpl.Funcs(template.FuncMap{
"formatTime": func(any) string { return "2024-01-01" },
"eq": func(a, b any) bool { return a == b },
"ne": func(a, b any) bool { return a != b },
"len": func(s any) int { return 0 },
"range": func(s any) any { return s },
})
for _, page := range pages {
if filepath.Base(page) == "base.gohtml" {
continue
}
for _, file := range templateFiles {
t.Run(file, func(t *testing.T) {
_, err := tmpl.ParseFiles(file)
assert.NoError(t, err, "Template %s should parse without errors", file)
page := page
t.Run(filepath.Base(page), func(t *testing.T) {
t.Parallel()
files := append([]string{layoutPath}, partials...)
files = append(files, page)
tmpl, err := template.New(filepath.Base(page)).Funcs(templateFuncMap()).ParseFiles(files...)
require.NoError(t, err)
require.NotNil(t, tmpl.Lookup("layout"), "layout template should be available")
require.NotNil(t, tmpl.Lookup("content"), "content block should be defined by page templates")
})
}
}

View File

@@ -422,6 +422,24 @@ func (m *MockUserRepository) GetDeletedUsers() ([]database.User, error) {
return []database.User{}, nil
}
func (m *MockUserRepository) GetByUsernamePrefix(prefix string) (*database.User, error) {
m.mu.RLock()
defer m.mu.RUnlock()
for _, user := range m.users {
if len(user.Username) >= len(prefix) && user.Username[:len(prefix)] == prefix {
if len(user.Email) >= 13 && strings.HasSuffix(user.Email, "@goyco.local") {
emailPrefix := user.Email[:len(user.Email)-13]
if len(emailPrefix) >= len(prefix) && emailPrefix[:len(prefix)] == prefix {
return user, nil
}
}
}
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockUserRepository) HardDeleteAll() (int64, error) {
if m.HardDeleteAllFunc != nil {
return m.HardDeleteAllFunc()
@@ -965,6 +983,25 @@ func (m *MockVoteRepository) CountByUserID(userID uint) (int64, error) {
return count, nil
}
func (m *MockVoteRepository) GetVoteCountsByPostID(postID uint) (int, int, error) {
m.mu.RLock()
defer m.mu.RUnlock()
upVotes := 0
downVotes := 0
for _, vote := range m.votes {
if vote.PostID == postID {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
}
return upVotes, downVotes, nil
}
func (m *MockVoteRepository) Count() (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()

View File

@@ -153,6 +153,7 @@ type UserRepositoryStub struct {
UnlockFn func(uint) error
GetPostsFn func(uint, int, int) ([]database.Post, error)
GetDeletedUsersFn func() ([]database.User, error)
GetByUsernamePrefixFn func(string) (*database.User, error)
HardDeleteAllFn func() (int64, error)
CountFn func() (int64, error)
WithTxFn func(*gorm.DB) repositories.UserRepository
@@ -281,6 +282,13 @@ func (s *UserRepositoryStub) GetDeletedUsers() ([]database.User, error) {
return nil, nil
}
func (s *UserRepositoryStub) GetByUsernamePrefix(prefix string) (*database.User, error) {
if s != nil && s.GetByUsernamePrefixFn != nil {
return s.GetByUsernamePrefixFn(prefix)
}
return nil, gorm.ErrRecordNotFound
}
func (s *UserRepositoryStub) HardDeleteAll() (int64, error) {
if s != nil && s.HardDeleteAllFn != nil {
return s.HardDeleteAllFn()

View File

@@ -30,7 +30,7 @@ if [ ! -d "docs" ]; then
mkdir -p docs
fi
SWAGGER_DIRECTORIES="cmd/goyco,internal/handlers"
SWAGGER_DIRECTORIES="cmd/goyco,internal/handlers,internal/dto"
SWAGGER_MAIN_FILE="main.go"
SWAGGER_OUTPUT_DIR="docs"

View File

@@ -2,21 +2,21 @@
# helper script to setup a postgres database on deb based systems
if [ "$EUID" -ne 0 ]; then
echo "Please run as root"
exit 1
echo "Please run as root"
exit 1
fi
read -s "Do you want to install PostgreSQL 18? [y/N] " INSTALL_PG
if [ "$INSTALL_PG" != "y" ]; then
echo "PostgreSQL 18 will not be installed"
exit 0
echo "PostgreSQL 18 will not be installed"
exit 0
fi
read -s -p "Enter password for PostgreSQL user 'goyco': " GOYCO_PWD
echo
apt-get update
apt-get install -y postgresql-18
apt-get install -y postgresql-18
systemctl enable --now postgresql
@@ -44,5 +44,3 @@ GRANT ALL PRIVILEGES ON DATABASE goyco TO goyco;
EOF
echo "PostgreSQL 18 installed, database 'goyco' and user 'goyco' set up."