Compare commits

...

130 Commits

Author SHA1 Message Date
817205d42f refactor: modernize using min() 2025-12-16 15:45:51 +01:00
199ac143a4 refactor: replace interface{} by any 2025-12-16 15:05:23 +01:00
aa7e259ed0 format: shfmt 2025-12-16 15:02:42 +01:00
4587609e17 refactor: create createTestRouter and test edge cases 2025-12-14 21:14:42 +01:00
33da6503e3 test: also test put/delete routes 2025-12-14 21:06:15 +01:00
cafc44ed77 test: add a test for route parameters 2025-12-14 21:04:36 +01:00
1480135e75 test: verified all routes to exist 2025-12-14 21:02:25 +01:00
02a764c736 clean: remove merged files 2025-12-14 20:52:14 +01:00
6834ad7764 refactor: merge facade, types and utils into one auth_service.go 2025-12-14 20:52:03 +01:00
dcf054046f test: fix parallel processor test expectations and setup 2025-12-10 07:30:27 +01:00
d2a788933d fix: track completed items in the main loop instead of using the index 2025-12-10 07:29:55 +01:00
18be3950dc clean: obsolete function 2025-12-09 22:07:30 +01:00
f9cb140e95 clean: removed the obsolete functions outputMessage and outputError 2025-12-09 22:06:12 +01:00
86d4835ccf feat: seed user is now uniq 2025-12-09 22:03:26 +01:00
feddb2ed43 test: new unit test for EnsureSeedUser 2025-12-09 22:03:16 +01:00
457b5c88e2 refactor: improve seed consistency validation 2025-12-09 21:53:03 +01:00
a8d363b2bf fix: templates now parse with the same func map as the page handler 2025-12-09 21:37:21 +01:00
0cd68e847c refactor: add a helper to centralize CSRF token retrieval 2025-12-09 15:58:28 +01:00
df6aeed713 docs: unocss it is 2025-12-04 20:43:30 +01:00
785faeb60c feat: update alpine to 3.23 2025-12-04 10:01:19 +01:00
0623c027ba docs: prepare CONTRIBUTING.md 2025-12-03 20:57:32 +01:00
d4e91b6034 refactor: complete refactor and better helpers use 2025-11-29 15:19:41 +01:00
7d46d3e81b clean: remove the unused expectedValue in assertHeader (always set to "") 2025-11-29 15:19:28 +01:00
216aaf3117 refactor: clean code and use new request helpers 2025-11-29 14:58:52 +01:00
435047ad0c refactor: clean code 2025-11-29 14:58:37 +01:00
b7ee8bd11d refactor: clean variable names and use new request helpers 2025-11-29 14:58:20 +01:00
040cd48be8 refactor: clean variables 2025-11-29 14:58:07 +01:00
2dd16e0e00 refactor: complete 2025-11-29 14:56:18 +01:00
d6db70cc79 refactor: clean code and variables, use new request helpers 2025-11-29 14:55:47 +01:00
58e10ade7d refactor: clean variable names and modernize code 2025-11-29 14:50:35 +01:00
7403a75d8e refactor: clean variable naming 2025-11-29 14:46:26 +01:00
b429bc11af refactor: clean code and use new request helpers 2025-11-29 14:41:38 +01:00
2ec5c28fb5 refactor: rename variables and clean code 2025-11-29 14:37:18 +01:00
3743a99e40 refactor: req -> request, rec -> recorder, reqBody -> requestBody... 2025-11-29 14:21:07 +01:00
5710921b87 refactor: use new request helpers 2025-11-29 14:17:25 +01:00
84d9c81484 refactor: rec -> recorder, req -> request and modernize loop 2025-11-29 14:15:07 +01:00
b0c2038927 feat: add new helpers to make requests properly in integration tests 2025-11-29 14:11:32 +01:00
fd88931146 refactor: variable names and modernize loop 2025-11-25 14:37:59 +01:00
6ce0f4dfad refactor: name variables 2025-11-25 10:13:10 +01:00
68e3dceefc refactor: name variables 2025-11-25 10:08:48 +01:00
cded14c526 fix: force correct mime types for static files after modifying compression middleware's buffering 2025-11-24 07:53:08 +01:00
cfca668ca6 docs: clean readme 2025-11-23 21:56:53 +01:00
279255b587 fix: don't let rate limiting fails the test 2025-11-23 21:48:11 +01:00
b83f8c2228 fix: update ValidationMiddleware to return a JSON error response when JSON decoding fails 2025-11-23 21:42:27 +01:00
aabc48128c fix: use router in handlers integration tests (for dto validation) 2025-11-23 15:10:55 +01:00
68716b977b fix: verify XSS sanitization in handler response instead of repository stub 2025-11-23 15:01:54 +01:00
dbe1600632 fix: indentation 2025-11-23 14:49:31 +01:00
458e25cf79 fix: modify compression middleware to pass through redirects immediately without buffering 2025-11-23 14:48:59 +01:00
d4595d8dbf fix: properly encoding the flash message in the redirect URL 2025-11-23 14:48:39 +01:00
c5418f4e4c docs: update swagger 2025-11-23 14:26:52 +01:00
db0369225e refactor: update references to VoteRequest 2025-11-23 14:26:45 +01:00
07ac965b3d refactor: use consistent naming (VoteRequest -> CastVoteRequest) 2025-11-23 14:26:19 +01:00
e2e5d42035 feat: add SetValidatedDTOInContext to support test helper functions 2025-11-23 14:22:59 +01:00
6e4b41894f fix: update test cases to use createCreatePostRequests 2025-11-23 14:22:35 +01:00
0a8ed2e27c fix: add explicite validation check for empty url, title and content length 2025-11-23 14:21:30 +01:00
216e8657f6 feat: add generic createRequestWithDTO along with helpers functions 2025-11-23 14:20:59 +01:00
fb7206c0a2 fix: test context handling 2025-11-23 14:20:09 +01:00
c25926514b fix: add explicit empty-field validation check in handlers 2025-11-23 14:19:54 +01:00
964785e494 docs: update swagger 2025-11-23 13:47:38 +01:00
9c67cd2a47 feat: update vote handler to use dto VoteRequest and update MountRoutes 2025-11-23 13:47:31 +01:00
8b5cc8e939 feat: add VoteRequest with its validation types 2025-11-23 13:46:51 +01:00
0e71b28615 feat: update CreateUser to use dto.RegisterRequest and update MountRoutes to apply validation middleware 2025-11-23 13:43:47 +01:00
cd740da57a feat: update methods to use validated DTOs and update MountRoutes 2025-11-23 13:43:14 +01:00
abe4a3dc88 feat: update handlers to use GetValidatedDTO instead of manual decoding and update MountRoutes to wrap handlers with WithValidation for all DTO-based routes 2025-11-23 13:42:52 +01:00
738243d945 feat: add ValidationMiddleware to RouteModuleConfig 2025-11-23 13:41:55 +01:00
4fbdfb6e4a feat: add two helpers function to retrieve validated DTOs from request context and to apply validation middleware 2025-11-23 13:41:07 +01:00
6bb3a78b88 feat: Add ValidationMiddleware to router configuration 2025-11-23 13:40:31 +01:00
54e37e59fc docs: update swagger 2025-11-23 13:35:00 +01:00
5d4b38ddc4 feat: add validation tags to request DTOs 2025-11-23 13:34:53 +01:00
7dc119ecde docs: update swagger 2025-11-23 13:17:14 +01:00
52c9f4a02b feat: add internal/dto to swagger directories 2025-11-23 13:16:44 +01:00
be91a135bc clean: empty line 2025-11-23 13:14:41 +01:00
2d7ff9778b feat: update swagger comments following dtos relocation 2025-11-23 13:14:07 +01:00
4ff3fd3583 refactor: remove UpdatePostRequest definition and update swagger comments 2025-11-23 13:13:53 +01:00
73121cad15 refactor: remove all request DTO, update swagger comments and update token related methods to use dto ones 2025-11-23 13:13:23 +01:00
c5bf1b2fd8 feat: locate post-related request DTOs 2025-11-23 13:12:36 +01:00
eedebe60d1 feat: locate auth-related request DTOs 2025-11-23 13:12:10 +01:00
80fb37371f update: fix go version and update alpine to 3.22 2025-11-23 10:44:28 +01:00
fea49fad8d fix: add missing method to mock 2025-11-21 17:07:26 +01:00
4b04461ebb style: minor formatting adjustments 2025-11-21 17:06:04 +01:00
533e8c3d46 feat: add GetByUsernamePrefixFn field and method to UserRepositoryStub 2025-11-21 17:05:48 +01:00
df568291f1 feat: add GetByUsernamePrefix implementation to MockUserRepository 2025-11-21 17:05:31 +01:00
81acce62b1 feat: add GetByUsernamePrefix method to interface and add implementation 2025-11-21 17:05:01 +01:00
989a61e7d5 feat: use getByUsernamePrefix to optimize findExistingSeedUser() 2025-11-21 17:04:35 +01:00
3ffd83b0fb feat: ignore docs in make format 2025-11-21 17:02:06 +01:00
62d466e4fa refactor: use go generics 2025-11-21 16:56:26 +01:00
0cd428d5d9 feat: use connection pooling instead of a single connection 2025-11-21 16:53:46 +01:00
5c239ad61d feat: add missing GetVoteCountsByPostID method to the errorVoteRepository test mock 2025-11-21 16:50:23 +01:00
01f2b1fe75 feat: remove loop and use GetVoteCountsByPostID 2025-11-21 16:48:48 +01:00
28134c101c feat: add GetVoteCountsByPostID to the mock for testing 2025-11-21 16:48:15 +01:00
2f78370d43 feat: GetVoteCountsByPostID: use a single sql query to returns up votes and down votes counts 2025-11-21 16:47:52 +01:00
39598a166d feat: remove redundat getbyemail call to reduce db query by 2 (1Q/user creation instead of 2) 2025-11-21 16:43:46 +01:00
fa9474d863 revert: db transaction use, avoiding the pgx RETURNING issue while maintaining data consistency 2025-11-21 16:31:06 +01:00
34a97994b3 feat: improve testing to use production code paths and better coverage 2025-11-21 16:26:21 +01:00
eb5f93ffd0 clean: remove duplicate sequential helpers 2025-11-21 16:25:27 +01:00
697f201d60 feat: use database transactions to ensure atomicity 2025-11-21 16:21:04 +01:00
f4ab8bda45 feat: transaction rollback test 2025-11-21 16:20:41 +01:00
65576cc623 feat: keep seeding fast and predictable even when parallelized 2025-11-21 16:16:35 +01:00
a5b4e9bf25 feat: update tests to pass precomputed hashes 2025-11-21 16:11:42 +01:00
c020517ccf feat: reduce hashing cost by removing redundant password hashing 2025-11-21 16:11:33 +01:00
4cdda3f944 feat: remove bcrypt and use a precompute hash 2025-11-21 16:11:08 +01:00
ff471cd5dd fix: loop 2025-11-21 15:39:08 +01:00
df5e67c7f3 feat: add idempotency tests 2025-11-21 15:34:08 +01:00
b2580d2380 feat: make seeding idempotente 2025-11-21 15:33:59 +01:00
4749213bf0 feat: update test to accept randomized seed user identities 2025-11-21 15:26:29 +01:00
6470425b96 feat: avoid unique constraint failures on repeat runs by randomizing seed identities 2025-11-21 15:26:05 +01:00
14ae6f815b feat: update tests to verify clamping 2025-11-21 15:21:05 +01:00
73083e4188 feat: check zero/negative value in seeding 2025-11-21 15:20:57 +01:00
c907c4812b feat: add tests covering negative values for the three flags 2025-11-21 14:56:09 +01:00
c7f30070c0 feat: reject negative/nonsensical flag values with a clear error instead of letting slice/channel allocations panic 2025-11-21 14:55:51 +01:00
0dcd5fec51 fix: Makefile fuzz target to enumerate fuzz functions per package and run each individually 2025-11-21 14:30:37 +01:00
96c054aa99 feat: add new parameter 2025-11-21 13:13:27 +01:00
bdd7766275 docs: update readme 2025-11-21 13:13:18 +01:00
0b1241d371 docs: update readme 2025-11-21 13:11:47 +01:00
b300fc2f5e feat: new test for json output flag from config 2025-11-21 13:11:43 +01:00
f49bea4138 feat(config): read a flag in .env to set or not json output 2025-11-21 13:11:16 +01:00
79e072fe6b feat(cli): read .env value to set or not the json output 2025-11-21 13:10:53 +01:00
7fca1f78dc feat(cli): add a json output and tests 2025-11-21 13:00:03 +01:00
30a2e88685 docs: update readme 2025-11-20 18:55:08 +01:00
10f7220fb6 feat: install pg 18 2025-11-20 18:55:05 +01:00
08a934e388 feat: migrate to postgres 18 2025-11-20 18:54:58 +01:00
87dadfa4a8 feat: modernize loops 2025-11-19 13:13:51 +01:00
53c76eee8b feat: modernize loop 2025-11-17 16:24:40 +01:00
446915d5ee feat: modernize loop 2025-11-17 15:59:24 +01:00
dceb305ac7 feat: modernize statement using max() 2025-11-17 15:37:03 +01:00
e5b1f18beb docs: update readme regarding our new linting configuration 2025-11-15 11:45:54 +01:00
513e0c05b2 feat: creating my own hell by adding a suitable configuration for golangci-lint 2025-11-15 11:45:42 +01:00
ba09bb4141 lint: fix staticcheck nil pointer issue 2025-11-15 11:35:31 +01:00
0f9a89bfc7 lint: add error checks on Register() and fix sqlDB.Close() errcheck 2025-11-15 11:34:57 +01:00
1d1a3dcf60 lint: make linter happy with errcheck 2025-11-15 11:34:12 +01:00
90 changed files with 5513 additions and 3784 deletions

View File

@@ -34,6 +34,9 @@ TITLE=Goyco
DEBUG=false
BCRYPT_COST=10
# CLI configuration
CLI_JSON_OUTPUT=false
# Rate limiting configuration (nb of request per minutes)
RATE_LIMIT_AUTH=10
RATE_LIMIT_GENERAL=200

96
.golangci.yml Normal file
View File

@@ -0,0 +1,96 @@
version: "2"
run:
timeout: 5m
tests: true
modules-download-mode: readonly
linters:
enable:
- govet
- staticcheck
- errcheck
- ineffassign
- unused
- gosec
- bodyclose
- noctx
- sqlclosecheck
- revive
- gocritic
- unparam
- wastedassign
- durationcheck
- goprintffuncname
- goconst
- misspell
- copyloopvar
- nilerr
linters-settings:
revive:
confidence: 0.8
ignore-generated-header: true
severity: warning
rules:
- name: indent-error-flow
- name: early-return
- name: error-naming
- name: var-naming
- name: if-return
- name: exported
- name: context-as-argument
- name: errorf
- name: unexported-return
- name: unnecessary-stmt
- name: unreachable-code
gocritic:
disabled-checks:
- hugeParam
- ifElseChain
settings:
captLocal:
paramsOnly: false
gosec:
excludes:
- G204
- G601
errcheck:
check-type-assertions: true
ignore: |
fmt:.*
staticcheck:
checks: ["all"]
ignore:
- SA1019
goconst:
min-len: 3
min-occurrences: 2
issues:
exclude-dirs:
- vendor
- tmp
- dist
exclude-rules:
- path: ".*_test\\.go"
linters:
- gosec
- path: "mock/|mocks/"
linters:
- gosec
- revive
- errcheck
max-issues-per-linter: 0
max-same-issues: 0
output:
format: colored-line-number

1
.prettierignore Normal file
View File

@@ -0,0 +1 @@
docs/

View File

@@ -1,4 +1,4 @@
ARG GO_VERSION=1.25.3
ARG GO_VERSION=1.25.4
# Building the binary using a golang alpine image
FROM golang:${GO_VERSION}-alpine AS go-builder
@@ -11,7 +11,7 @@ ARG TARGETARCH=amd64
RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags="-s -w" -o /out/goyco ./cmd/goyco
# building the application image
FROM alpine:3.21
FROM alpine:3.23
RUN addgroup -S goyco && adduser -S -G goyco goyco \
&& apk add --no-cache ca-certificates tzdata
WORKDIR /app

View File

@@ -44,8 +44,8 @@ clean:
rm -fr dist/*
format:
$(PRETTIER) -w .
$(GO) fmt ./...
$(PRETTIER) -w . --ignore-path .prettierignore
$(GO) fmt $(shell $(GO) list ./... | grep -v 'docs')
lint:
$(GOLANGCI_LINT) run
@@ -77,7 +77,15 @@ fuzz-tests:
@set -e; \
for pkg in $(FUZZ_PACKAGES); do \
echo "==> Fuzzing $$pkg"; \
$(GO) test -fuzz=. -fuzztime=$(FUZZ_TIME) $$pkg; \
fuzz_targets="$$( $(GO) test -run ^$$ -list ^Fuzz $$pkg | grep '^Fuzz' || true )"; \
if [ -z "$$fuzz_targets" ]; then \
echo "No fuzz tests found in $$pkg"; \
continue; \
fi; \
for fuzz in $$fuzz_targets; do \
echo " -> $$fuzz"; \
$(GO) test -run ^$$ -fuzz="^$$fuzz$$" -fuzztime=$(FUZZ_TIME) $$pkg; \
done; \
done
install:

View File

@@ -1,7 +1,7 @@
# Goyco
[![Go Version](https://img.shields.io/badge/Go-1.25.0-blue.svg)](https://golang.org/)
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-17-blue.svg)](https://www.postgresql.org/)
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-18-blue.svg)](https://www.postgresql.org/)
[![License](https://img.shields.io/badge/License-GPLv3-green.svg)](LICENSE)
Goyco is a Y Combinator-style news aggregation platform built in Go. It will allow you to host your own news aggregation platform, with a modern-ish UI and a fully functional REST API.
@@ -18,14 +18,14 @@ You can get a preview of the application through the [screenshots](screenshots).
### Technology Stack
It's basically pure Go (using Chi router), raw CSS and PostgreSQL 17.
It's basically pure Go (using Chi router), raw CSS and PostgreSQL 18.
## Quick Start
### Prerequisites
- Go 1.25.0 or later
- PostgreSQL 17 or later
- PostgreSQL 18 or later
- SMTP server for email functionality
### Setup PostgreSQL database and user
@@ -251,6 +251,37 @@ Goyco includes a comprehensive CLI for administration:
./bin/goyco prune all # Hard delete all users and posts
```
### JSON Output
All CLI commands support JSON output for easier parsing and integration with scripts. Use the `--json` flag to enable structured JSON output:
```bash
# Get JSON output
./bin/goyco --json user list
./bin/goyco --json post list
./bin/goyco --json status
# Example: Parse JSON output with jq
./bin/goyco --json user list | jq '.users[0].username'
./bin/goyco --json status | jq '.status'
```
You can also set JSON output as the default by setting the `CLI_JSON_OUTPUT` environment variable to `true` in your `.env` file:
```bash
CLI_JSON_OUTPUT=true
```
When set, all CLI commands will output JSON by default.
**Note for destructive operations**: When using `--json` with `prune` commands, you must also use the `--yes` flag to skip interactive confirmation prompts:
```bash
./bin/goyco --json prune posts --yes
./bin/goyco --json prune users --yes --with-posts
./bin/goyco --json prune all --yes
```
## Development
### Get the sources
@@ -386,6 +417,8 @@ make format
make lint
```
Note: `golangci-lint` is set up with `.golangci.yml` file.
### Regerenate Swagger documentation
If you make changes to the API, you can regenerate the swagger documentation by running the following command after modifying the swagger annotations:
@@ -404,31 +437,10 @@ This will regenerate the swagger documentation and update the `docs/swagger.json
- [ ] add right management within the app
- [ ] add an admin backoffice to manage rights, users, content and settings
- [ ] add a way to run read-only communities
- [ ] maybe use a css framework instead of raw css
- [ ] migrate raw CSS to UnoCSS
- [ ] kubernetes deployment
- [ ] store configuration in the database
## Contributing
Feedbacks are welcome!
But as it's a personal gitea and you cannot create accounts, feel free to contact me at <sandro@cazzaniga.fr> to get one.
Once you have it, follow the usual workflow:
1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests for new functionality
5. Ensure all tests pass
6. Submit a pull request
Then, I'll review your changes and merge them if they are good.
## License
This project is licensed under the GNU General Public License v3.0 or later (GPLv3+). See the [LICENSE](LICENSE) file for details.
---
**Goyco** - A modern news aggregation platform built with Go, PostgreSQL and most importantly, love.

View File

@@ -6,8 +6,9 @@ import (
"fmt"
"os"
"github.com/joho/godotenv"
"goyco/cmd/goyco/commands"
"github.com/joho/godotenv"
)
func loadDotEnv() {

View File

@@ -1,32 +1,51 @@
package commands
import (
"encoding/json"
"errors"
"flag"
"fmt"
"os"
"sync"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"gorm.io/gorm"
)
var ErrHelpRequested = errors.New("help requested")
type DBConnector func(cfg *config.Config) (*gorm.DB, func() error, error)
var (
jsonOutputMu sync.RWMutex
jsonOutput bool
)
func SetJSONOutput(enabled bool) {
jsonOutputMu.Lock()
defer jsonOutputMu.Unlock()
jsonOutput = enabled
}
func IsJSONOutput() bool {
jsonOutputMu.RLock()
defer jsonOutputMu.RUnlock()
return jsonOutput
}
var (
dbConnectorMu sync.RWMutex
currentDBConnector = defaultDBConnector
)
func defaultDBConnector(cfg *config.Config) (*gorm.DB, func() error, error) {
db, err := database.Connect(cfg)
poolManager, err := database.ConnectWithPool(cfg)
if err != nil {
return nil, nil, err
}
return db, func() error { return database.Close(db) }, nil
return poolManager.GetDB(), func() error { return poolManager.Close() }, nil
}
func SetDBConnector(connector DBConnector) {
@@ -93,3 +112,19 @@ func truncate(in string, max int) string {
}
return in[:max-3] + "..."
}
func outputJSON(v interface{}) error {
encoder := json.NewEncoder(os.Stdout)
encoder.SetIndent("", " ")
return encoder.Encode(v)
}
func outputWarning(message string, args ...interface{}) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"warning": fmt.Sprintf(message, args...),
})
} else {
fmt.Printf("Warning: "+message+"\n", args...)
}
}

View File

@@ -217,3 +217,41 @@ func setInMemoryDBConnector(t *testing.T) {
SetDBConnector(nil)
})
}
func TestSetJSONOutput(t *testing.T) {
t.Run("set and get JSON output", func(t *testing.T) {
SetJSONOutput(true)
if !IsJSONOutput() {
t.Error("expected JSON output to be enabled")
}
SetJSONOutput(false)
if IsJSONOutput() {
t.Error("expected JSON output to be disabled")
}
})
t.Run("concurrent access", func(t *testing.T) {
SetJSONOutput(false)
done := make(chan bool)
go func() {
for i := 0; i < 100; i++ {
SetJSONOutput(true)
_ = IsJSONOutput()
SetJSONOutput(false)
}
done <- true
}()
go func() {
for i := 0; i < 100; i++ {
_ = IsJSONOutput()
}
done <- true
}()
<-done
<-done
})
}

View File

@@ -86,23 +86,50 @@ func runStatusCommand(cfg *config.Config) error {
pidFile := filepath.Join(pidDir, "goyco.pid")
if !isDaemonRunning(pidFile) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"status": "not_running",
})
} else {
fmt.Println("Goyco is not running")
}
return nil
}
data, err := os.ReadFile(pidFile)
if err != nil {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"status": "running",
"error": fmt.Sprintf("PID file exists but cannot be read: %v", err),
})
} else {
fmt.Printf("Goyco is running (PID file exists but cannot be read: %v)\n", err)
}
return nil
}
pid, err := strconv.Atoi(string(data))
if err != nil {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"status": "running",
"error": fmt.Sprintf("PID file exists but contains invalid PID: %v", err),
})
} else {
fmt.Printf("Goyco is running (PID file exists but contains invalid PID: %v)\n", err)
}
return nil
}
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"status": "running",
"pid": pid,
})
} else {
fmt.Printf("Goyco is running (PID %d)\n", pid)
}
return nil
}
@@ -143,7 +170,14 @@ func stopDaemon(cfg *config.Config) error {
_ = os.Remove(pidFile)
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "stopped",
"pid": pid,
})
} else {
fmt.Printf("Goyco stopped (PID %d)\n", pid)
}
return nil
}
@@ -184,9 +218,18 @@ func runDaemon(cfg *config.Config) error {
if err := writePIDFile(pidFile, pid); err != nil {
return fmt.Errorf("cannot write PID file: %w", err)
}
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "started",
"pid": pid,
"pid_file": pidFile,
"log_file": logFile,
})
} else {
fmt.Printf("Goyco started with PID %d\n", pid)
fmt.Printf("PID file: %s\n", pidFile)
fmt.Printf("Log file: %s\n", logFile)
}
return nil
}

View File

@@ -100,6 +100,20 @@ func TestRunStatusCommand(t *testing.T) {
}
})
t.Run("daemon not running with JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
tempDir := t.TempDir()
cfg.PIDDir = tempDir
err := runStatusCommand(cfg)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("daemon running with valid PID", func(t *testing.T) {
tempDir := t.TempDir()
cfg.PIDDir = tempDir
@@ -118,6 +132,27 @@ func TestRunStatusCommand(t *testing.T) {
}
})
t.Run("daemon running with valid PID and JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
tempDir := t.TempDir()
cfg.PIDDir = tempDir
pidFile := filepath.Join(tempDir, "goyco.pid")
currentPID := os.Getpid()
err := os.WriteFile(pidFile, []byte(strconv.Itoa(currentPID)), 0644)
if err != nil {
t.Fatalf("Failed to create PID file: %v", err)
}
err = runStatusCommand(cfg)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("daemon running with invalid PID file", func(t *testing.T) {
tempDir := t.TempDir()
cfg.PIDDir = tempDir

View File

@@ -5,9 +5,10 @@ import (
"fmt"
"os"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"gorm.io/gorm"
)
func HandleMigrateCommand(cfg *config.Config, name string, args []string) error {
@@ -30,11 +31,20 @@ func HandleMigrateCommand(cfg *config.Config, name string, args []string) error
}
func runMigrateCommand(db *gorm.DB) error {
if !IsJSONOutput() {
fmt.Println("Running database migrations...")
}
if err := database.Migrate(db); err != nil {
return fmt.Errorf("run migrations: %w", err)
}
if IsJSONOutput() {
outputJSON(map[string]any{
"action": "migrations_applied",
"status": "success",
})
} else {
fmt.Println("Migrations applied successfully")
}
return nil
}

View File

@@ -39,4 +39,18 @@ func TestHandleMigrateCommand(t *testing.T) {
t.Fatalf("unexpected error running migrations: %v", err)
}
})
t.Run("runs migrations with JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
cfg := testutils.NewTestConfig()
setInMemoryDBConnector(t)
err := HandleMigrateCommand(cfg, "migrate", []string{})
if err != nil {
t.Fatalf("unexpected error running migrations: %v", err)
}
})
}

View File

@@ -2,14 +2,13 @@ package commands
import (
"context"
"crypto/rand"
cryptoRand "crypto/rand"
"fmt"
"math/big"
"math/rand"
"runtime"
"sync"
"time"
"golang.org/x/crypto/bcrypt"
"goyco/internal/database"
"goyco/internal/repositories"
)
@@ -17,93 +16,163 @@ import (
type ParallelProcessor struct {
maxWorkers int
timeout time.Duration
passwordHash string
randSource *rand.Rand
randMu sync.Mutex
}
func NewParallelProcessor() *ParallelProcessor {
maxWorkers := max(min(runtime.NumCPU(), 8), 2)
seed := time.Now().UnixNano()
seedBytes := make([]byte, 8)
if _, err := cryptoRand.Read(seedBytes); err == nil {
seed = int64(seedBytes[0])<<56 | int64(seedBytes[1])<<48 | int64(seedBytes[2])<<40 | int64(seedBytes[3])<<32 |
int64(seedBytes[4])<<24 | int64(seedBytes[5])<<16 | int64(seedBytes[6])<<8 | int64(seedBytes[7])
}
return &ParallelProcessor{
maxWorkers: maxWorkers,
timeout: 30 * time.Second,
timeout: 60 * time.Second,
randSource: rand.New(rand.NewSource(seed)),
}
}
func (p *ParallelProcessor) SetPasswordHash(hash string) {
p.passwordHash = hash
}
type indexedResult[T any] struct {
value T
index int
}
func processInParallel[T any](
ctx context.Context,
maxWorkers int,
count int,
processor func(index int) (T, error),
errorPrefix string,
progress *ProgressIndicator,
) ([]T, error) {
results := make(chan indexedResult[T], count)
errors := make(chan error, count)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i := range count {
wg.Add(1)
go func(index int) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
value, err := processor(index + 1)
if err != nil {
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return
}
results <- indexedResult[T]{value: value, index: index}
}(i)
}
go func() {
wg.Wait()
close(results)
close(errors)
}()
items := make([]T, count)
completed := 0
firstError := make(chan error, 1)
go func() {
for err := range errors {
if err != nil {
select {
case firstError <- err:
default:
}
return
}
}
}()
for completed < count {
select {
case result, ok := <-results:
if !ok {
return items, nil
}
items[result.index] = result.value
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return nil, err
case <-ctx.Done():
return nil, fmt.Errorf("timeout: %w", ctx.Err())
}
}
return items, nil
}
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan userResult, count)
errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers)
var wg sync.WaitGroup
for i := range count {
wg.Add(1)
go func(index int) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
user, err := p.createSingleUser(userRepo, index+1)
if err != nil {
errors <- fmt.Errorf("create user %d: %w", index+1, err)
return
}
results <- userResult{user: user, index: index}
}(i)
}
go func() {
wg.Wait()
close(results)
close(errors)
}()
users := make([]database.User, count)
completed := 0
for {
select {
case result, ok := <-results:
if !ok {
return users, nil
}
users[result.index] = result.user
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-errors:
if err != nil {
return nil, err
}
case <-ctx.Done():
return nil, fmt.Errorf("timeout creating users: %w", ctx.Err())
}
}
return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.User, error) {
return p.createSingleUser(userRepo, index)
},
"create user",
progress,
)
}
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan postResult, count)
return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.Post, error) {
return p.createSinglePost(postRepo, authorID, index)
},
"create post",
progress,
)
}
func processItemsInParallel[T any, R any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) (R, error),
errorPrefix string,
aggregator func(accumulator R, value R) R,
initialValue R,
progress *ProgressIndicator,
) (R, error) {
count := len(items)
results := make(chan indexedResult[R], count)
errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i := range count {
for i, item := range items {
wg.Add(1)
go func(index int) {
go func(index int, item T) {
defer wg.Done()
select {
@@ -114,14 +183,14 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
}
defer func() { <-semaphore }()
post, err := p.createSinglePost(postRepo, authorID, index+1)
value, err := processor(index, item)
if err != nil {
errors <- fmt.Errorf("create post %d: %w", index+1, err)
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return
}
results <- postResult{post: post, index: index}
}(i)
results <- indexedResult[R]{value: value, index: index}
}(i, item)
}
go func() {
@@ -130,43 +199,76 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
close(errors)
}()
posts := make([]database.Post, count)
accumulator := initialValue
completed := 0
firstError := make(chan error, 1)
for {
go func() {
for err := range errors {
if err != nil {
select {
case firstError <- err:
default:
}
return
}
}
}()
for completed < count {
select {
case result, ok := <-results:
if !ok {
return posts, nil
return accumulator, nil
}
posts[result.index] = result.post
accumulator = aggregator(accumulator, result.value)
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-errors:
if err != nil {
return nil, err
}
case err := <-firstError:
return initialValue, err
case <-ctx.Done():
return nil, fmt.Errorf("timeout creating posts: %w", ctx.Err())
return initialValue, fmt.Errorf("timeout: %w", ctx.Err())
}
}
return accumulator, nil
}
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan voteResult, len(posts))
errors := make(chan error, len(posts))
return processItemsInParallel(ctx, p.maxWorkers, posts,
func(index int, post database.Post) (int, error) {
return p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
},
"create votes for post",
func(acc, val int) int { return acc + val },
0,
progress,
)
}
semaphore := make(chan struct{}, p.maxWorkers)
func processItemsInParallelNoResult[T any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) error,
errorFormatter func(index int, item T, err error) error,
progress *ProgressIndicator,
) error {
count := len(items)
errors := make(chan error, count)
completions := make(chan struct{}, count)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
for i, item := range items {
wg.Add(1)
go func(index int, post database.Post) {
go func(index int, item T) {
defer wg.Done()
select {
@@ -177,131 +279,113 @@ func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepo
}
defer func() { <-semaphore }()
votes, err := p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
err := processor(index, item)
if err != nil {
errors <- fmt.Errorf("create votes for post %d: %w", post.ID, err)
if errorFormatter != nil {
errors <- errorFormatter(index, item, err)
} else {
errors <- fmt.Errorf("process item %d: %w", index+1, err)
}
return
}
results <- voteResult{votes: votes, index: index}
}(i, post)
completions <- struct{}{}
}(i, item)
}
go func() {
wg.Wait()
close(results)
close(errors)
close(completions)
}()
totalVotes := 0
completed := 0
firstError := make(chan error, 1)
for {
go func() {
for err := range errors {
if err != nil {
select {
case result, ok := <-results:
if !ok {
return totalVotes, nil
case firstError <- err:
default:
}
return
}
}
}()
for completed < count {
select {
case _, ok := <-completions:
if !ok {
return nil
}
totalVotes += result.votes
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-errors:
if err != nil {
return 0, err
}
case <-ctx.Done():
return 0, fmt.Errorf("timeout creating votes: %w", ctx.Err())
}
}
}
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
errors := make(chan error, len(posts))
semaphore := make(chan struct{}, p.maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
wg.Add(1)
go func(index int, post database.Post) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
err := p.updateSinglePostScore(postRepo, voteRepo, post)
if err != nil {
errors <- fmt.Errorf("update post %d scores: %w", post.ID, err)
return
}
if progress != nil {
progress.Update(index + 1)
}
}(i, post)
}
go func() {
wg.Wait()
close(errors)
}()
for err := range errors {
if err != nil {
case err := <-firstError:
return err
case <-ctx.Done():
return fmt.Errorf("timeout: %w", ctx.Err())
}
}
return nil
}
type userResult struct {
user database.User
index int
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
return processItemsInParallelNoResult(ctx, p.maxWorkers, posts,
func(index int, post database.Post) error {
return p.updateSinglePostScore(postRepo, voteRepo, post)
},
func(index int, post database.Post, err error) error {
return fmt.Errorf("update post %d scores: %w", post.ID, err)
},
progress,
)
}
type postResult struct {
post database.Post
index int
}
type voteResult struct {
votes int
index int
func (p *ParallelProcessor) generateRandomIdentifier() string {
const length = 12
const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
identifier := make([]byte, length)
p.randMu.Lock()
for i := range identifier {
identifier[i] = chars[p.randSource.Intn(len(chars))]
}
p.randMu.Unlock()
return string(identifier)
}
func (p *ParallelProcessor) createSingleUser(userRepo repositories.UserRepository, index int) (database.User, error) {
username := fmt.Sprintf("user_%d", index)
email := fmt.Sprintf("user_%d@goyco.local", index)
password := "password123"
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return database.User{}, fmt.Errorf("hash password: %w", err)
}
randomID := p.generateRandomIdentifier()
username := fmt.Sprintf("user_%s", randomID)
email := fmt.Sprintf("user_%s@goyco.local", randomID)
const maxRetries = 10
for range maxRetries {
user := &database.User{
Username: username,
Email: email,
Password: string(hashedPassword),
Password: p.passwordHash,
EmailVerified: true,
}
if err := userRepo.Create(user); err != nil {
return database.User{}, fmt.Errorf("create user: %w", err)
randomID = p.generateRandomIdentifier()
username = fmt.Sprintf("user_%s", randomID)
email = fmt.Sprintf("user_%s@goyco.local", randomID)
continue
}
return *user, nil
}
return database.User{}, fmt.Errorf("failed to create user after %d attempts", maxRetries)
}
func (p *ParallelProcessor) createSinglePost(postRepo repositories.PostRepository, authorID uint, index int) (database.Post, error) {
@@ -347,11 +431,14 @@ func (p *ParallelProcessor) createSinglePost(postRepo repositories.PostRepositor
}
domain := sampleDomains[index%len(sampleDomains)]
path := generateRandomPath()
randomID := p.generateRandomIdentifier()
path := fmt.Sprintf("/article/%s", randomID)
url := fmt.Sprintf("https://%s%s", domain, path)
content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", index, title)
const maxRetries = 10
for range maxRetries {
post := &database.Post{
Title: title,
URL: url,
@@ -363,38 +450,50 @@ func (p *ParallelProcessor) createSinglePost(postRepo repositories.PostRepositor
}
if err := postRepo.Create(post); err != nil {
return database.Post{}, fmt.Errorf("create post: %w", err)
randomID = p.generateRandomIdentifier()
path = fmt.Sprintf("/article/%s", randomID)
url = fmt.Sprintf("https://%s%s", domain, path)
continue
}
return *post, nil
}
return database.Post{}, fmt.Errorf("failed to create post after %d attempts", maxRetries)
}
func (p *ParallelProcessor) createVotesForPost(voteRepo repositories.VoteRepository, users []database.User, post database.Post, avgVotesPerPost int) (int, error) {
voteCount, _ := rand.Int(rand.Reader, big.NewInt(int64(avgVotesPerPost*2)+1))
numVotes := int(voteCount.Int64())
p.randMu.Lock()
numVotes := p.randSource.Intn(avgVotesPerPost*2 + 1)
p.randMu.Unlock()
if numVotes == 0 && avgVotesPerPost > 0 {
chance, _ := rand.Int(rand.Reader, big.NewInt(5))
if chance.Int64() > 0 {
p.randMu.Lock()
if p.randSource.Intn(5) > 0 {
numVotes = 1
}
p.randMu.Unlock()
}
totalVotes := 0
usedUsers := make(map[uint]bool)
for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
userIdx, _ := rand.Int(rand.Reader, big.NewInt(int64(len(users))))
user := users[userIdx.Int64()]
p.randMu.Lock()
userIdx := p.randSource.Intn(len(users))
p.randMu.Unlock()
user := users[userIdx]
if usedUsers[user.ID] {
continue
}
usedUsers[user.ID] = true
voteTypeInt, _ := rand.Int(rand.Reader, big.NewInt(10))
p.randMu.Lock()
voteTypeInt := p.randSource.Intn(10)
p.randMu.Unlock()
var voteType database.VoteType
if voteTypeInt.Int64() < 7 {
if voteTypeInt < 7 {
voteType = database.VoteUp
} else {
voteType = database.VoteDown
@@ -406,8 +505,8 @@ func (p *ParallelProcessor) createVotesForPost(voteRepo repositories.VoteReposit
Type: voteType,
}
if err := voteRepo.Create(vote); err != nil {
return totalVotes, fmt.Errorf("create vote: %w", err)
if err := voteRepo.CreateOrUpdate(vote); err != nil {
return totalVotes, fmt.Errorf("create or update vote: %w", err)
}
totalVotes++

View File

@@ -2,15 +2,15 @@ package commands_test
import (
"errors"
"fmt"
"sync"
"testing"
"golang.org/x/crypto/bcrypt"
"goyco/cmd/goyco/commands"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/testutils"
"golang.org/x/crypto/bcrypt"
)
func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
@@ -25,7 +25,7 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
wantErr bool
}{
{
name: "creates users with deterministic fields",
name: "creates users with required fields",
count: successCount,
repoFactory: func() repositories.UserRepository {
base := testutils.NewMockUserRepository()
@@ -37,14 +37,24 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
if len(got) != successCount {
t.Fatalf("expected %d users, got %d", successCount, len(got))
}
usernames := make(map[string]bool)
for i, user := range got {
expectedUsername := fmt.Sprintf("user_%d", i+1)
expectedEmail := fmt.Sprintf("user_%d@goyco.local", i+1)
if user.Username != expectedUsername {
t.Errorf("user %d username mismatch: got %q want %q", i, user.Username, expectedUsername)
if user.Username == "" {
t.Errorf("user %d expected non-empty username", i)
}
if user.Email != expectedEmail {
t.Errorf("user %d email mismatch: got %q want %q", i, user.Email, expectedEmail)
if len(user.Username) < 6 || user.Username[:5] != "user_" {
t.Errorf("user %d username should start with 'user_', got %q", i, user.Username)
}
if usernames[user.Username] {
t.Errorf("user %d duplicate username: %q", i, user.Username)
}
usernames[user.Username] = true
if user.Email == "" {
t.Errorf("user %d expected non-empty email", i)
}
if len(user.Email) < 20 || user.Email[:5] != "user_" || user.Email[len(user.Email)-12:] != "@goyco.local" {
t.Errorf("user %d email should match pattern 'user_*@goyco.local', got %q", i, user.Email)
}
if !user.EmailVerified {
t.Errorf("user %d expected EmailVerified to be true", i)
@@ -83,6 +93,11 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
t.Parallel()
repo := tt.repoFactory()
p := commands.NewParallelProcessor()
passwordHash, err := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("failed to generate password hash: %v", err)
}
p.SetPasswordHash(string(passwordHash))
got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress)
if gotErr != nil {
if !tt.wantErr {

View File

@@ -90,7 +90,14 @@ func postDelete(repo repositories.PostRepository, args []string) error {
return fmt.Errorf("delete post: %w", err)
}
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "post_deleted",
"id": id,
})
} else {
fmt.Printf("Post deleted: ID=%d\n", id)
}
return nil
}
@@ -126,6 +133,38 @@ func postList(postQueries *services.PostQueries, args []string) error {
return fmt.Errorf("list posts: %w", err)
}
if IsJSONOutput() {
type postJSON struct {
ID uint `json:"id"`
Title string `json:"title"`
AuthorID uint `json:"author_id"`
Score int `json:"score"`
CreatedAt string `json:"created_at"`
}
postsJSON := make([]postJSON, len(posts))
for i, p := range posts {
authorID := uint(0)
if p.AuthorID != nil {
authorID = *p.AuthorID
}
if p.Author.ID != 0 {
authorID = p.Author.ID
}
postsJSON[i] = postJSON{
ID: p.ID,
Title: p.Title,
AuthorID: authorID,
Score: p.Score,
CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"),
}
}
outputJSON(map[string]interface{}{
"posts": postsJSON,
"count": len(postsJSON),
})
return nil
}
if len(posts) == 0 {
fmt.Println("No posts found")
return nil
@@ -234,6 +273,39 @@ func postSearch(postQueries *services.PostQueries, args []string) error {
return fmt.Errorf("search posts: %w", err)
}
if IsJSONOutput() {
type postJSON struct {
ID uint `json:"id"`
Title string `json:"title"`
AuthorID uint `json:"author_id"`
Score int `json:"score"`
CreatedAt string `json:"created_at"`
}
postsJSON := make([]postJSON, len(posts))
for i, p := range posts {
authorID := uint(0)
if p.AuthorID != nil {
authorID = *p.AuthorID
}
if p.Author.ID != 0 {
authorID = p.Author.ID
}
postsJSON[i] = postJSON{
ID: p.ID,
Title: p.Title,
AuthorID: authorID,
Score: p.Score,
CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"),
}
}
outputJSON(map[string]interface{}{
"search_term": sanitizedTerm,
"posts": postsJSON,
"count": len(postsJSON),
})
return nil
}
if len(posts) == 0 {
fmt.Println("No posts found matching your search")
return nil

View File

@@ -89,6 +89,26 @@ func TestPostDelete(t *testing.T) {
}
})
t.Run("successful delete with JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
freshMockRepo := testutils.NewMockPostRepository()
testPost := &database.Post{
Title: "Test Post",
Content: "Test Content",
AuthorID: &[]uint{1}[0],
Score: 0,
}
_ = freshMockRepo.Create(testPost)
err := postDelete(freshMockRepo, []string{"1"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("missing id", func(t *testing.T) {
err := postDelete(mockRepo, []string{})
@@ -174,6 +194,17 @@ func TestPostList(t *testing.T) {
}
})
t.Run("list all posts with JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
err := postList(postQueries, []string{})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("list with limit", func(t *testing.T) {
err := postList(postQueries, []string{"--limit", "1"})
@@ -271,6 +302,17 @@ func TestPostSearch(t *testing.T) {
}
})
t.Run("search with results and JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
err := postSearch(postQueries, []string{"Go"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("case insensitive search", func(t *testing.T) {
mockRepo.SearchCalls = nil

View File

@@ -62,6 +62,7 @@ func printPruneUsage() {
func prunePosts(postRepo repositories.PostRepository, args []string) error {
fs := flag.NewFlagSet("prune posts", flag.ContinueOnError)
dryRun := fs.Bool("dry-run", false, "preview what would be deleted without actually deleting")
yes := fs.Bool("yes", false, "skip confirmation prompt")
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
@@ -73,6 +74,49 @@ func prunePosts(postRepo repositories.PostRepository, args []string) error {
return fmt.Errorf("get posts by deleted users: %w", err)
}
if IsJSONOutput() {
type postJSON struct {
ID uint `json:"id"`
Title string `json:"title"`
Author string `json:"author"`
URL string `json:"url"`
}
postsJSON := make([]postJSON, len(posts))
for i, post := range posts {
authorName := "(deleted)"
if post.Author.ID != 0 {
authorName = post.Author.Username
}
postsJSON[i] = postJSON{
ID: post.ID,
Title: post.Title,
Author: authorName,
URL: post.URL,
}
}
if *dryRun {
outputJSON(map[string]interface{}{
"action": "prune_posts",
"dry_run": true,
"posts": postsJSON,
"count": len(postsJSON),
})
return nil
}
if !*yes {
return fmt.Errorf("confirmation required. Use --yes to skip prompt in JSON mode")
}
deletedCount, err := postRepo.HardDeletePostsByDeletedUsers()
if err != nil {
return fmt.Errorf("hard delete posts: %w", err)
}
outputJSON(map[string]interface{}{
"action": "prune_posts",
"deleted_count": deletedCount,
})
return nil
}
if len(posts) == 0 {
fmt.Println("No posts found for deleted users")
return nil
@@ -93,6 +137,7 @@ func prunePosts(postRepo repositories.PostRepository, args []string) error {
return nil
}
if !*yes {
fmt.Printf("\nAre you sure you want to permanently delete %d posts? (yes/no): ", len(posts))
var confirmation string
if _, err := fmt.Scanln(&confirmation); err != nil {
@@ -103,6 +148,7 @@ func prunePosts(postRepo repositories.PostRepository, args []string) error {
fmt.Println("Operation cancelled")
return nil
}
}
deletedCount, err := postRepo.HardDeletePostsByDeletedUsers()
if err != nil {
@@ -117,6 +163,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
fs := flag.NewFlagSet("prune users", flag.ContinueOnError)
dryRun := fs.Bool("dry-run", false, "preview what would be deleted without actually deleting")
deletePosts := fs.Bool("with-posts", false, "also delete all posts when deleting users")
yes := fs.Bool("yes", false, "skip confirmation prompt")
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
@@ -130,7 +177,14 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
userCount := len(users)
if userCount == 0 {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "prune_users",
"count": 0,
})
} else {
fmt.Println("No users found to delete")
}
return nil
}
@@ -142,6 +196,61 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
}
}
if IsJSONOutput() {
type userJSON struct {
ID uint `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
}
usersJSON := make([]userJSON, len(users))
for i, user := range users {
usersJSON[i] = userJSON{
ID: user.ID,
Username: user.Username,
Email: user.Email,
}
}
if *dryRun {
outputJSON(map[string]interface{}{
"action": "prune_users",
"dry_run": true,
"users": usersJSON,
"user_count": userCount,
"post_count": postCount,
"with_posts": *deletePosts,
})
return nil
}
if !*yes {
return fmt.Errorf("confirmation required. Use --yes to skip prompt or --json for non-interactive mode")
}
if *deletePosts {
totalDeleted, err := userRepo.HardDeleteAll()
if err != nil {
return fmt.Errorf("hard delete all users and posts: %w", err)
}
outputJSON(map[string]interface{}{
"action": "prune_users",
"deleted_count": totalDeleted,
"with_posts": true,
})
} else {
deletedCount := 0
for _, user := range users {
if err := userRepo.SoftDeleteWithPosts(user.ID); err != nil {
return fmt.Errorf("soft delete user %d: %w", user.ID, err)
}
deletedCount++
}
outputJSON(map[string]interface{}{
"action": "prune_users",
"deleted_count": deletedCount,
"with_posts": false,
})
}
return nil
}
fmt.Printf("Found %d users", userCount)
if *deletePosts {
fmt.Printf(" and %d posts", postCount)
@@ -158,6 +267,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
return nil
}
if !*yes {
confirmMsg := fmt.Sprintf("\nAre you sure you want to permanently delete %d users", userCount)
if *deletePosts {
confirmMsg += fmt.Sprintf(" and %d posts", postCount)
@@ -174,6 +284,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
fmt.Println("Operation cancelled")
return nil
}
}
if *deletePosts {
totalDeleted, err := userRepo.HardDeleteAll()
@@ -198,6 +309,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
func pruneAll(userRepo repositories.UserRepository, postRepo repositories.PostRepository, args []string) error {
fs := flag.NewFlagSet("prune all", flag.ContinueOnError)
dryRun := fs.Bool("dry-run", false, "preview what would be deleted without actually deleting")
yes := fs.Bool("yes", false, "skip confirmation prompt")
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
@@ -214,6 +326,30 @@ func pruneAll(userRepo repositories.UserRepository, postRepo repositories.PostRe
return fmt.Errorf("get post count: %w", err)
}
if IsJSONOutput() {
if *dryRun {
outputJSON(map[string]interface{}{
"action": "prune_all",
"dry_run": true,
"user_count": len(userCount),
"post_count": postCount,
})
return nil
}
if !*yes {
return fmt.Errorf("confirmation required. Use --yes to skip prompt or --json for non-interactive mode")
}
totalDeleted, err := userRepo.HardDeleteAll()
if err != nil {
return fmt.Errorf("hard delete all: %w", err)
}
outputJSON(map[string]interface{}{
"action": "prune_all",
"deleted_count": totalDeleted,
})
return nil
}
fmt.Printf("Found %d users and %d posts to delete\n", len(userCount), postCount)
if *dryRun {
@@ -221,6 +357,7 @@ func pruneAll(userRepo repositories.UserRepository, postRepo repositories.PostRe
return nil
}
if !*yes {
fmt.Printf("\nAre you sure you want to permanently delete ALL %d users and %d posts? (yes/no): ", len(userCount), postCount)
var confirmation string
if _, err := fmt.Scanln(&confirmation); err != nil {
@@ -231,6 +368,7 @@ func pruneAll(userRepo repositories.UserRepository, postRepo repositories.PostRe
fmt.Println("Operation cancelled")
return nil
}
}
totalDeleted, err := userRepo.HardDeleteAll()
if err != nil {

View File

@@ -110,6 +110,16 @@ func TestPrunePosts(t *testing.T) {
t.Errorf("prunePosts() with dry-run error = %v", err)
}
t.Run("prunePosts with JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
err := prunePosts(postRepo, []string{"--dry-run"})
if err != nil {
t.Errorf("prunePosts() with dry-run and JSON output error = %v", err)
}
})
post1 := database.Post{
ID: 1,
Title: "Post by deleted user 1",
@@ -138,6 +148,16 @@ func TestPrunePosts(t *testing.T) {
if err != nil {
t.Errorf("prunePosts() with dry-run error = %v", err)
}
t.Run("prunePosts with JSON output and mock data", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
err := prunePosts(postRepo, []string{"--dry-run"})
if err != nil {
t.Errorf("prunePosts() with dry-run and JSON output error = %v", err)
}
})
}
func TestPruneAll(t *testing.T) {
@@ -175,6 +195,16 @@ func TestPruneAll(t *testing.T) {
if err != nil {
t.Errorf("pruneAll() with dry-run error = %v", err)
}
t.Run("pruneAll with JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
err := pruneAll(userRepo, postRepo, []string{"--dry-run"})
if err != nil {
t.Errorf("pruneAll() with dry-run and JSON output error = %v", err)
}
})
}
func TestPrunePostsWithError(t *testing.T) {

View File

@@ -1,20 +1,40 @@
package commands
import (
"crypto/rand"
cryptoRand "crypto/rand"
"errors"
"flag"
"fmt"
"math/big"
"math/rand"
"os"
"sync"
"time"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/repositories"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
seedRandSource *rand.Rand
seedRandOnce sync.Once
)
func initSeedRand() {
seedRandOnce.Do(func() {
seed := time.Now().UnixNano()
seedBytes := make([]byte, 8)
if _, err := cryptoRand.Read(seedBytes); err == nil {
seed = int64(seedBytes[0])<<56 | int64(seedBytes[1])<<48 | int64(seedBytes[2])<<40 | int64(seedBytes[3])<<32 |
int64(seedBytes[4])<<24 | int64(seedBytes[5])<<16 | int64(seedBytes[6])<<8 | int64(seedBytes[7])
}
seedRandSource = rand.New(rand.NewSource(seed))
})
}
func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printSeedUsage)
if err := parseCommand(fs, args, name); err != nil {
@@ -69,285 +89,231 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
return err
}
originalUsers := *numUsers
originalPosts := *numPosts
originalVotesPerPost := *votesPerPost
if *numUsers < 0 {
if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --users value %d is negative, clamping to 0\n", *numUsers)
}
*numUsers = 0
}
if *numPosts <= 0 {
if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --posts value %d is too low, clamping to 1\n", *numPosts)
}
*numPosts = 1
}
if *votesPerPost < 0 {
if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --votes-per-post value %d is negative, clamping to 0\n", *votesPerPost)
}
*votesPerPost = 0
}
if !IsJSONOutput() && (originalUsers != *numUsers || originalPosts != *numPosts || originalVotesPerPost != *votesPerPost) {
fmt.Fprintf(os.Stderr, "Using clamped values: --users=%d --posts=%d --votes-per-post=%d\n", *numUsers, *numPosts, *votesPerPost)
}
if !IsJSONOutput() {
fmt.Println("Starting database seeding...")
}
seedPassword := "seed-password"
userPassword := "password123"
seedPasswordHash, err := bcrypt.GenerateFromPassword([]byte(seedPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("precompute seed password hash: %w", err)
}
userPasswordHash, err := bcrypt.GenerateFromPassword([]byte(userPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("precompute user password hash: %w", err)
}
spinner := NewSpinner("Creating seed user")
if !IsJSONOutput() {
spinner.Spin()
}
seedUser, err := ensureSeedUser(userRepo)
seedUser, err := ensureSeedUser(userRepo, string(seedPasswordHash))
if err != nil {
if !IsJSONOutput() {
spinner.Complete()
}
return fmt.Errorf("ensure seed user: %w", err)
}
if !IsJSONOutput() {
spinner.Complete()
fmt.Printf("Seed user ready: ID=%d Username=%s\n", seedUser.ID, seedUser.Username)
}
processor := NewParallelProcessor()
processor.SetPasswordHash(string(userPasswordHash))
progress := NewProgressIndicator(*numUsers, "Creating users (parallel)")
var progress *ProgressIndicator
if !IsJSONOutput() && *numUsers > 0 {
progress = NewProgressIndicator(*numUsers, "Creating users (parallel)")
}
users, err := processor.CreateUsersInParallel(userRepo, *numUsers, progress)
if err != nil {
return fmt.Errorf("create random users: %w", err)
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
}
allUsers := append([]database.User{*seedUser}, users...)
if !IsJSONOutput() && *numPosts > 0 {
progress = NewProgressIndicator(*numPosts, "Creating posts (parallel)")
}
posts, err := processor.CreatePostsInParallel(postRepo, seedUser.ID, *numPosts, progress)
if err != nil {
return fmt.Errorf("create random posts: %w", err)
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
}
if !IsJSONOutput() && len(posts) > 0 {
progress = NewProgressIndicator(len(posts), "Creating votes (parallel)")
}
votes, err := processor.CreateVotesInParallel(voteRepo, allUsers, posts, *votesPerPost, progress)
if err != nil {
return fmt.Errorf("create random votes: %w", err)
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
}
if !IsJSONOutput() && len(posts) > 0 {
progress = NewProgressIndicator(len(posts), "Updating scores (parallel)")
}
err = processor.UpdatePostScoresInParallel(postRepo, voteRepo, posts, progress)
if err != nil {
return fmt.Errorf("update post scores: %w", err)
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
}
if err := validateSeedConsistency(voteRepo, allUsers, posts); err != nil {
return fmt.Errorf("seed consistency validation failed: %w", err)
}
if IsJSONOutput() {
outputJSON(map[string]any{
"action": "seed_completed",
"users": len(allUsers),
"posts": len(posts),
"votes": votes,
"seed_user": map[string]any{
"id": seedUser.ID,
"username": seedUser.Username,
},
})
} else {
fmt.Println("Database seeding completed successfully!")
fmt.Printf("Created %d users, %d posts, and %d votes\n", len(allUsers), len(posts), votes)
}
return nil
}
func ensureSeedUser(userRepo repositories.UserRepository) (*database.User, error) {
seedUsername := "seed_admin"
seedEmail := "seed_admin@goyco.local"
seedPassword := "seed-password"
const (
seedUsername = "seed_admin"
seedEmail = "seed_admin@goyco.local"
)
user, err := userRepo.GetByEmail(seedEmail)
if err == nil {
func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) {
if user, err := userRepo.GetByUsername(seedUsername); err == nil {
return user, nil
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(seedPassword), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
user = &database.User{
Username: seedUsername,
Email: seedEmail,
Password: string(hashedPassword),
EmailVerified: true,
}
if err := userRepo.Create(user); err != nil {
return nil, fmt.Errorf("create seed user: %w", err)
}
return user, nil
}
func createRandomUsers(userRepo repositories.UserRepository, count int) ([]database.User, error) {
var users []database.User
for i := range count {
username := fmt.Sprintf("user_%d", i+1)
email := fmt.Sprintf("user_%d@goyco.local", i+1)
password := "password123"
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("hash password for user %d: %w", i+1, err)
}
user := &database.User{
Username: username,
Email: email,
Password: string(hashedPassword),
Username: seedUsername,
Email: seedEmail,
Password: passwordHash,
EmailVerified: true,
}
if err := userRepo.Create(user); err != nil {
return nil, fmt.Errorf("create user %d: %w", i+1, err)
return nil, fmt.Errorf("failed to create seed user: %w", err)
}
users = append(users, *user)
}
return users, nil
return user, nil
}
func createRandomPosts(postRepo repositories.PostRepository, authorID uint, count int) ([]database.Post, error) {
var posts []database.Post
sampleTitles := []string{
"Amazing JavaScript Framework",
"Python Best Practices",
"Go Performance Tips",
"Database Optimization",
"Web Security Guide",
"Machine Learning Basics",
"Cloud Architecture",
"DevOps Automation",
"API Design Patterns",
"Frontend Optimization",
"Backend Scaling",
"Container Orchestration",
"Microservices Architecture",
"Testing Strategies",
"Code Review Process",
"Version Control Best Practices",
"Continuous Integration",
"Monitoring and Alerting",
"Error Handling Patterns",
"Data Structures Explained",
}
sampleDomains := []string{
"example.com",
"techblog.org",
"devguide.net",
"programming.io",
"codeexamples.com",
"tutorialhub.org",
"bestpractices.dev",
"learnprogramming.net",
"codingtips.org",
"softwareengineering.com",
}
for i := range count {
title := sampleTitles[i%len(sampleTitles)]
if i >= len(sampleTitles) {
title = fmt.Sprintf("%s - Part %d", title, (i/len(sampleTitles))+1)
}
domain := sampleDomains[i%len(sampleDomains)]
path := generateRandomPath()
url := fmt.Sprintf("https://%s%s", domain, path)
content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", i+1, title)
post := &database.Post{
Title: title,
URL: url,
Content: content,
AuthorID: &authorID,
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
if err := postRepo.Create(post); err != nil {
return nil, fmt.Errorf("create post %d: %w", i+1, err)
}
posts = append(posts, *post)
}
return posts, nil
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
return voteRepo.GetVoteCountsByPostID(postID)
}
func generateRandomPath() string {
pathLength, _ := rand.Int(rand.Reader, big.NewInt(20))
path := "/article/"
for i := int64(0); i < pathLength.Int64()+5; i++ {
randomChar, _ := rand.Int(rand.Reader, big.NewInt(26))
path += string(rune('a' + randomChar.Int64()))
func validateSeedConsistency(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
userIDSet := make(map[uint]struct{}, len(users))
for _, user := range users {
userIDSet[user.ID] = struct{}{}
}
return path
}
func createRandomVotes(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int) (int, error) {
totalVotes := 0
postIDSet := make(map[uint]struct{}, len(posts))
for _, post := range posts {
postIDSet[post.ID] = struct{}{}
}
for _, post := range posts {
voteCount, _ := rand.Int(rand.Reader, big.NewInt(int64(avgVotesPerPost*2)+1))
numVotes := int(voteCount.Int64())
if numVotes == 0 && avgVotesPerPost > 0 {
chance, _ := rand.Int(rand.Reader, big.NewInt(5))
if chance.Int64() > 0 {
numVotes = 1
}
if err := validatePost(post, userIDSet); err != nil {
return err
}
usedUsers := make(map[uint]bool)
for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
userIdx, _ := rand.Int(rand.Reader, big.NewInt(int64(len(users))))
user := users[userIdx.Int64()]
if usedUsers[user.ID] {
continue
}
usedUsers[user.ID] = true
voteTypeInt, _ := rand.Int(rand.Reader, big.NewInt(10))
var voteType database.VoteType
if voteTypeInt.Int64() < 7 {
voteType = database.VoteUp
} else {
voteType = database.VoteDown
}
vote := &database.Vote{
UserID: &user.ID,
PostID: post.ID,
Type: voteType,
}
if err := voteRepo.Create(vote); err != nil {
return totalVotes, fmt.Errorf("create vote for post %d: %w", post.ID, err)
}
totalVotes++
}
}
return totalVotes, nil
}
func updatePostScores(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post) error {
for _, post := range posts {
upVotes, downVotes, err := getVoteCounts(voteRepo, post.ID)
votes, err := voteRepo.GetByPostID(post.ID)
if err != nil {
return fmt.Errorf("get vote counts for post %d: %w", post.ID, err)
return fmt.Errorf("failed to retrieve votes for post %d: %w", post.ID, err)
}
post.UpVotes = upVotes
post.DownVotes = downVotes
post.Score = upVotes - downVotes
if err := postRepo.Update(&post); err != nil {
return fmt.Errorf("update post %d scores: %w", post.ID, err)
if err := validateVotesForPost(post.ID, votes, userIDSet, postIDSet); err != nil {
return err
}
}
return nil
}
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
votes, err := voteRepo.GetByPostID(postID)
if err != nil {
return 0, 0, err
func validatePost(post database.Post, userIDSet map[uint]struct{}) error {
if post.AuthorID == nil {
return fmt.Errorf("post %d has no author ID", post.ID)
}
upVotes := 0
downVotes := 0
for _, vote := range votes {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
if _, exists := userIDSet[*post.AuthorID]; !exists {
return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID)
}
return upVotes, downVotes, nil
return nil
}
func validateVotesForPost(postID uint, votes []database.Vote, userIDSet map[uint]struct{}, postIDSet map[uint]struct{}) error {
for _, vote := range votes {
if vote.PostID != postID {
return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, postID)
}
if _, exists := postIDSet[vote.PostID]; !exists {
return fmt.Errorf("vote %d references non-existent post ID %d", vote.ID, vote.PostID)
}
if vote.UserID != nil {
if _, exists := userIDSet[*vote.UserID]; !exists {
return fmt.Errorf("vote %d references non-existent user ID %d", vote.ID, *vote.UserID)
}
}
if vote.Type != database.VoteUp && vote.Type != database.VoteDown {
return fmt.Errorf("vote %d has invalid type %q", vote.ID, vote.Type)
}
}
return nil
}

View File

@@ -1,13 +1,16 @@
package commands
import (
"fmt"
"strings"
"testing"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/testutils"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func TestSeedCommand(t *testing.T) {
@@ -21,39 +24,64 @@ func TestSeedCommand(t *testing.T) {
t.Fatalf("Failed to migrate database: %v", err)
}
err = db.Transaction(func(tx *gorm.DB) error {
userRepo := repositories.NewUserRepository(db).WithTx(tx)
postRepo := repositories.NewPostRepository(db).WithTx(tx)
voteRepo := repositories.NewVoteRepository(db).WithTx(tx)
return seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "2", "--posts", "5", "--votes-per-post", "3"})
})
if err != nil {
t.Fatalf("Failed to seed database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
seedUser, err := ensureSeedUser(userRepo)
users, err := userRepo.GetAll(100, 0)
if err != nil {
t.Fatalf("Failed to ensure seed user: %v", err)
t.Fatalf("Failed to get users: %v", err)
}
seedUserCount := 0
var seedUser *database.User
regularUserCount := 0
for i := range users {
if users[i].Username == "seed_admin" {
seedUserCount++
seedUser = &users[i]
} else if strings.HasPrefix(users[i].Username, "user_") {
regularUserCount++
}
}
if seedUserCount != 1 {
t.Errorf("Expected 1 seed user, got %d", seedUserCount)
}
if seedUser == nil {
t.Fatal("Expected seed user to be created")
}
if seedUser.Username != "seed_admin" {
t.Errorf("Expected username 'seed_admin', got '%s'", seedUser.Username)
t.Errorf("Expected username to be 'seed_admin', got '%s'", seedUser.Username)
}
if seedUser.Email != "seed_admin@goyco.local" {
t.Errorf("Expected email 'seed_admin@goyco.local', got '%s'", seedUser.Email)
t.Errorf("Expected email to be 'seed_admin@goyco.local', got '%s'", seedUser.Email)
}
if !seedUser.EmailVerified {
t.Error("Expected seed user to be email verified")
}
users, err := createRandomUsers(userRepo, 2)
if err != nil {
t.Fatalf("Failed to create random users: %v", err)
if regularUserCount != 2 {
t.Errorf("Expected 2 regular users, got %d", regularUserCount)
}
if len(users) != 2 {
t.Errorf("Expected 2 users, got %d", len(users))
}
posts, err := createRandomPosts(postRepo, seedUser.ID, 5)
posts, err := postRepo.GetAll(100, 0)
if err != nil {
t.Fatalf("Failed to create random posts: %v", err)
t.Fatalf("Failed to get posts: %v", err)
}
if len(posts) != 5 {
@@ -70,39 +98,49 @@ func TestSeedCommand(t *testing.T) {
if post.AuthorID == nil || *post.AuthorID != seedUser.ID {
t.Errorf("Post %d has wrong author ID: expected %d, got %v", i, seedUser.ID, post.AuthorID)
}
expectedScore := post.UpVotes - post.DownVotes
if post.Score != expectedScore {
t.Errorf("Post %d has incorrect score: expected %d, got %d", i, expectedScore, post.Score)
}
}
allUsers := append([]database.User{*seedUser}, users...)
votes, err := createRandomVotes(voteRepo, allUsers, posts, 3)
voteCount, err := voteRepo.Count()
if err != nil {
t.Fatalf("Failed to create random votes: %v", err)
t.Fatalf("Failed to count votes: %v", err)
}
if votes == 0 {
if voteCount == 0 {
t.Error("Expected some votes to be created")
}
err = updatePostScores(postRepo, voteRepo, posts)
for _, post := range posts {
postVotes, err := voteRepo.GetByPostID(post.ID)
if err != nil {
t.Fatalf("Failed to update post scores: %v", err)
}
for i, post := range posts {
updatedPost, err := postRepo.GetByID(post.ID)
if err != nil {
t.Errorf("Failed to get updated post %d: %v", i, err)
t.Errorf("Failed to get votes for post %d: %v", post.ID, err)
continue
}
expectedScore := updatedPost.UpVotes - updatedPost.DownVotes
if updatedPost.Score != expectedScore {
t.Errorf("Post %d has incorrect score: expected %d, got %d", i, expectedScore, updatedPost.Score)
for _, vote := range postVotes {
if vote.PostID != post.ID {
t.Errorf("Vote has wrong post ID: expected %d, got %d", post.ID, vote.PostID)
}
if vote.UserID == nil {
t.Error("Vote has nil user ID")
}
}
}
}
func TestGenerateRandomPath(t *testing.T) {
path := generateRandomPath()
initSeedRand()
pathLength := seedRandSource.Intn(20)
path := "/article/"
for i := 0; i < pathLength+5; i++ {
randomChar := seedRandSource.Intn(26)
path += string(rune('a' + randomChar))
}
if path == "" {
t.Error("Generated path should not be empty")
@@ -112,7 +150,14 @@ func TestGenerateRandomPath(t *testing.T) {
t.Errorf("Generated path too short: %s", path)
}
secondPath := generateRandomPath()
initSeedRand()
secondPathLength := seedRandSource.Intn(20)
secondPath := "/article/"
for i := 0; i < secondPathLength+5; i++ {
randomChar := seedRandSource.Intn(26)
secondPath += string(rune('a' + randomChar))
}
if path == secondPath {
t.Error("Generated paths should be different")
}
@@ -178,4 +223,311 @@ func TestSeedDatabaseFlagParsing(t *testing.T) {
t.Error("expected error for missing votes-per-post value")
}
})
t.Run("negative users value is clamped", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "-1", "--posts", "1"})
if err != nil {
t.Errorf("negative users should be clamped, not rejected. Got error: %v", err)
}
})
t.Run("negative posts value is clamped", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--posts", "-5"})
if err != nil {
t.Errorf("negative posts should be clamped, not rejected. Got error: %v", err)
}
})
t.Run("zero posts value is clamped", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--posts", "0"})
if err != nil {
t.Errorf("zero posts should be clamped, not rejected. Got error: %v", err)
}
})
t.Run("negative votes-per-post value is clamped", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--votes-per-post", "-10", "--posts", "1"})
if err != nil {
t.Errorf("negative votes-per-post should be clamped, not rejected. Got error: %v", err)
}
})
t.Run("zero users value is valid", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "0", "--posts", "1"})
if err != nil {
t.Errorf("zero users should be valid, got error: %v", err)
}
})
t.Run("zero votes-per-post value is valid", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--votes-per-post", "0", "--posts", "1"})
if err != nil {
t.Errorf("zero votes-per-post should be valid, got error: %v", err)
}
})
}
func TestSeedCommandIdempotency(t *testing.T) {
dbName := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
err = db.AutoMigrate(&database.User{}, &database.Post{}, &database.Vote{})
if err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
t.Run("first run creates seed user", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "1", "--posts", "2"})
if err != nil {
t.Fatalf("First seed run failed: %v", err)
}
users, err := userRepo.GetAll(100, 0)
if err != nil {
t.Fatalf("Failed to get users: %v", err)
}
seedUserCount := 0
for _, user := range users {
if user.Username == "seed_admin" {
seedUserCount++
}
}
if seedUserCount != 1 {
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
}
})
t.Run("second run reuses seed user", func(t *testing.T) {
usersBefore, _ := userRepo.GetAll(100, 0)
seedUserBefore := findSeedUser(usersBefore)
if seedUserBefore == nil {
t.Fatal("No seed user found before second run")
}
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "1", "--posts", "2"})
if err != nil {
t.Fatalf("Second seed run failed: %v", err)
}
usersAfter, _ := userRepo.GetAll(100, 0)
seedUserAfter := findSeedUser(usersAfter)
if seedUserAfter == nil {
t.Fatal("Seed user not found after second run")
}
if seedUserBefore.ID != seedUserAfter.ID {
t.Errorf("Expected seed user to be reused (ID %d), but got different user (ID %d)", seedUserBefore.ID, seedUserAfter.ID)
}
})
t.Run("database remains consistent after multiple runs", func(t *testing.T) {
for i := range 2 {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "0", "--posts", "1"})
if err != nil {
t.Fatalf("Seed run %d failed: %v", i+1, err)
}
}
users, _ := userRepo.GetAll(100, 0)
posts, _ := postRepo.GetAll(100, 0)
for _, post := range posts {
if post.AuthorID == nil {
t.Errorf("Post %d has no author", post.ID)
continue
}
authorExists := false
for _, user := range users {
if user.ID == *post.AuthorID {
authorExists = true
break
}
}
if !authorExists {
t.Errorf("Post %d has invalid author ID %d", post.ID, *post.AuthorID)
}
votes, _ := voteRepo.GetByPostID(post.ID)
for _, vote := range votes {
if vote.UserID != nil {
userExists := false
for _, user := range users {
if user.ID == *vote.UserID {
userExists = true
break
}
}
if !userExists {
t.Errorf("Vote %d has invalid user ID %d", vote.ID, *vote.UserID)
}
}
}
}
})
}
func findSeedUser(users []database.User) *database.User {
for i := range users {
if users[i].Username == "seed_admin" {
return &users[i]
}
}
return nil
}
func TestSeedCommandTransactionRollback(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
err = db.AutoMigrate(&database.User{}, &database.Post{}, &database.Vote{})
if err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
t.Run("transaction rolls back on failure", func(t *testing.T) {
initialUserCount, _ := userRepo.Count()
initialPostCount, _ := postRepo.Count()
initialVoteCount, _ := voteRepo.Count()
err := db.Transaction(func(tx *gorm.DB) error {
txUserRepo := userRepo.WithTx(tx)
txPostRepo := postRepo.WithTx(tx)
txVoteRepo := voteRepo.WithTx(tx)
err := seedDatabase(txUserRepo, txPostRepo, txVoteRepo, []string{"--users", "2", "--posts", "3"})
if err != nil {
return err
}
return fmt.Errorf("simulated failure")
})
if err == nil {
t.Fatal("Expected transaction to fail")
}
finalUserCount, _ := userRepo.Count()
finalPostCount, _ := postRepo.Count()
finalVoteCount, _ := voteRepo.Count()
if finalUserCount != initialUserCount {
t.Errorf("Expected user count to remain %d after rollback, got %d", initialUserCount, finalUserCount)
}
if finalPostCount != initialPostCount {
t.Errorf("Expected post count to remain %d after rollback, got %d", initialPostCount, finalPostCount)
}
if finalVoteCount != initialVoteCount {
t.Errorf("Expected vote count to remain %d after rollback, got %d", initialVoteCount, finalVoteCount)
}
})
t.Run("transaction commits on success", func(t *testing.T) {
initialUserCount, _ := userRepo.Count()
initialPostCount, _ := postRepo.Count()
err := db.Transaction(func(tx *gorm.DB) error {
txUserRepo := userRepo.WithTx(tx)
txPostRepo := postRepo.WithTx(tx)
txVoteRepo := voteRepo.WithTx(tx)
return seedDatabase(txUserRepo, txPostRepo, txVoteRepo, []string{"--users", "1", "--posts", "1"})
})
if err != nil {
t.Fatalf("Expected transaction to succeed, got error: %v", err)
}
finalUserCount, _ := userRepo.Count()
finalPostCount, _ := postRepo.Count()
expectedUsers := initialUserCount + 2
expectedPosts := initialPostCount + 1
if finalUserCount < expectedUsers {
t.Errorf("Expected at least %d users after commit, got %d", expectedUsers, finalUserCount)
}
if finalPostCount < expectedPosts {
t.Errorf("Expected at least %d posts after commit, got %d", expectedPosts, finalPostCount)
}
})
}
func TestEnsureSeedUser(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
if err := db.AutoMigrate(&database.User{}); err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
passwordHash := "test_password_hash"
firstUser, err := ensureSeedUser(userRepo, passwordHash)
if err != nil {
t.Fatalf("Failed to create seed user: %v", err)
}
if firstUser.Username != "seed_admin" || firstUser.Email != "seed_admin@goyco.local" || firstUser.Password != passwordHash || !firstUser.EmailVerified {
t.Errorf("Invalid seed user: username=%s, email=%s, password matches=%v, emailVerified=%v",
firstUser.Username, firstUser.Email, firstUser.Password == passwordHash, firstUser.EmailVerified)
}
secondUser, err := ensureSeedUser(userRepo, "different_password_hash")
if err != nil {
t.Fatalf("Failed to reuse seed user: %v", err)
}
if firstUser.ID != secondUser.ID {
t.Errorf("Expected same user to be reused (ID %d), got different user (ID %d)", firstUser.ID, secondUser.ID)
}
for i := 0; i < 3; i++ {
if _, err := ensureSeedUser(userRepo, passwordHash); err != nil {
t.Fatalf("Call %d failed: %v", i+1, err)
}
}
users, err := userRepo.GetAll(100, 0)
if err != nil {
t.Fatalf("Failed to get users: %v", err)
}
seedUserCount := 0
for _, user := range users {
if user.Username == "seed_admin" {
seedUserCount++
}
}
if seedUserCount != 1 {
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
}
}

View File

@@ -98,7 +98,7 @@ func userCreate(cfg *config.Config, repo repositories.UserRepository, args []str
auditLogger, err := NewAuditLogger(cfg.LogDir)
if err != nil {
fmt.Printf("Warning: Could not initialize audit logging: %v\n", err)
outputWarning("Could not initialize audit logging: %v", err)
auditLogger = nil
}
@@ -168,7 +168,16 @@ func userCreate(cfg *config.Config, repo repositories.UserRepository, args []str
auditLogger.LogUserCreation(user.ID, user.Username, user.Email, true, nil)
}
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "user_created",
"id": user.ID,
"username": user.Username,
"email": user.Email,
})
} else {
fmt.Printf("User created: %s (%s)\n", user.Username, user.Email)
}
return nil
}
@@ -286,7 +295,16 @@ func userUpdate(cfg *config.Config, repo repositories.UserRepository, refreshTok
return handleDatabaseConstraintError(err)
}
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "user_updated",
"id": user.ID,
"username": user.Username,
"email": user.Email,
})
} else {
fmt.Printf("User updated: %s (%s)\n", user.Username, user.Email)
}
return nil
}
@@ -383,16 +401,13 @@ func userDelete(cfg *config.Config, repo repositories.UserRepository, args []str
}
var deleteErr error
var postsDeleted bool
if *deletePosts {
deleteErr = repo.HardDelete(uint(id))
if deleteErr == nil {
fmt.Printf("User deleted: ID=%d (posts also deleted)\n", id)
}
postsDeleted = true
} else {
deleteErr = repo.SoftDeleteWithPosts(uint(id))
if deleteErr == nil {
fmt.Printf("User deleted: ID=%d (posts kept)\n", id)
}
postsDeleted = false
}
if deleteErr != nil {
@@ -402,11 +417,31 @@ func userDelete(cfg *config.Config, repo repositories.UserRepository, args []str
emailSender := services.NewSMTPSenderWithTimeout(cfg.SMTP.Host, cfg.SMTP.Port, cfg.SMTP.Username, cfg.SMTP.Password, cfg.SMTP.From, cfg.SMTP.Timeout)
subject, body := services.GenerateAdminAccountDeletionNotificationEmail(user.Username, cfg.App.AdminEmail, cfg.App.BaseURL, cfg.App.Title, *deletePosts)
emailSent := true
if err := emailSender.Send(user.Email, subject, body); err != nil {
fmt.Printf("Warning: Could not send notification email to %s: %v\n", user.Email, err)
outputWarning("Could not send notification email to %s: %v", user.Email, err)
emailSent = false
}
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "user_deleted",
"id": id,
"username": user.Username,
"email": user.Email,
"posts_deleted": postsDeleted,
"email_sent": emailSent,
})
} else {
if postsDeleted {
fmt.Printf("User deleted: ID=%d (posts also deleted)\n", id)
} else {
fmt.Printf("User deleted: ID=%d (posts kept)\n", id)
}
if emailSent {
fmt.Printf("Notification email sent to %s\n", user.Email)
}
}
return nil
}
@@ -426,6 +461,31 @@ func userList(repo repositories.UserRepository, args []string) error {
return fmt.Errorf("list users: %w", err)
}
if IsJSONOutput() {
type userJSON struct {
ID uint `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Locked bool `json:"locked"`
CreatedAt string `json:"created_at"`
}
usersJSON := make([]userJSON, len(users))
for i, u := range users {
usersJSON[i] = userJSON{
ID: u.ID,
Username: u.Username,
Email: u.Email,
Locked: u.Locked,
CreatedAt: u.CreatedAt.Format("2006-01-02 15:04:05"),
}
}
outputJSON(map[string]interface{}{
"users": usersJSON,
"count": len(usersJSON),
})
return nil
}
if len(users) == 0 {
fmt.Println("No users found")
return nil
@@ -517,7 +577,16 @@ func userLock(cfg *config.Config, repo repositories.UserRepository, args []strin
}
if user.Locked {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "user_lock",
"id": id,
"username": user.Username,
"status": "already_locked",
})
} else {
fmt.Printf("User is already locked: %s\n", user.Username)
}
return nil
}
@@ -525,16 +594,28 @@ func userLock(cfg *config.Config, repo repositories.UserRepository, args []strin
return fmt.Errorf("lock user: %w", err)
}
fmt.Printf("User locked: %s\n", user.Username)
emailSender := services.NewSMTPSenderWithTimeout(cfg.SMTP.Host, cfg.SMTP.Port, cfg.SMTP.Username, cfg.SMTP.Password, cfg.SMTP.From, cfg.SMTP.Timeout)
subject, body := services.GenerateAccountLockNotificationEmail(user.Username, cfg.App.AdminEmail, cfg.App.Title)
emailSent := true
if err := emailSender.Send(user.Email, subject, body); err != nil {
fmt.Printf("Warning: Could not send notification email to %s: %v\n", user.Email, err)
outputWarning("Could not send notification email to %s: %v", user.Email, err)
emailSent = false
}
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "user_locked",
"id": id,
"username": user.Username,
"email_sent": emailSent,
})
} else {
fmt.Printf("User locked: %s\n", user.Username)
if emailSent {
fmt.Printf("Notification email sent to %s\n", user.Email)
}
}
return nil
}
@@ -571,7 +652,16 @@ func userUnlock(cfg *config.Config, repo repositories.UserRepository, args []str
}
if !user.Locked {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "user_unlock",
"id": id,
"username": user.Username,
"status": "already_unlocked",
})
} else {
fmt.Printf("User is already unlocked: %s\n", user.Username)
}
return nil
}
@@ -579,16 +669,28 @@ func userUnlock(cfg *config.Config, repo repositories.UserRepository, args []str
return fmt.Errorf("unlock user: %w", err)
}
fmt.Printf("User unlocked: %s\n", user.Username)
emailSender := services.NewSMTPSenderWithTimeout(cfg.SMTP.Host, cfg.SMTP.Port, cfg.SMTP.Username, cfg.SMTP.Password, cfg.SMTP.From, cfg.SMTP.Timeout)
subject, body := services.GenerateAccountUnlockNotificationEmail(user.Username, cfg.App.AdminEmail, cfg.App.BaseURL, cfg.App.Title)
emailSent := true
if err := emailSender.Send(user.Email, subject, body); err != nil {
fmt.Printf("Warning: Could not send notification email to %s: %v\n", user.Email, err)
outputWarning("Could not send notification email to %s: %v", user.Email, err)
emailSent = false
}
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "user_unlocked",
"id": id,
"username": user.Username,
"email_sent": emailSent,
})
} else {
fmt.Printf("User unlocked: %s\n", user.Username)
if emailSent {
fmt.Printf("Notification email sent to %s\n", user.Email)
}
}
return nil
}
@@ -629,8 +731,18 @@ func resetUserPassword(cfg *config.Config, repo repositories.UserRepository, ses
return fmt.Errorf("send password reset email: %w", err)
}
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"action": "password_reset",
"id": userID,
"username": user.Username,
"email": user.Email,
"message": "Temporary password sent. User must change password on next login.",
})
} else {
fmt.Printf("Password reset for user %s: Temporary password sent to %s\n", user.Username, user.Email)
fmt.Printf("⚠️ User must change this password on next login!\n")
}
return nil
}

View File

@@ -81,6 +81,22 @@ func TestUserCreate(t *testing.T) {
}
})
t.Run("successful creation with JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
"--username", "testuser",
"--email", "test@example.com",
"--password", "StrongPass123!",
})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("missing username", func(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
@@ -239,6 +255,22 @@ func TestUserUpdate(t *testing.T) {
}
})
t.Run("successful update username with JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
cfg := &config.Config{}
mockRefreshRepo := &mockRefreshTokenRepo{}
err := userUpdate(cfg, mockRepo, mockRefreshRepo, []string{
"1",
"--username", "newusername",
})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("successful update email", func(t *testing.T) {
cfg := &config.Config{}
mockRefreshRepo := &mockRefreshTokenRepo{}
@@ -351,6 +383,25 @@ func TestUserDelete(t *testing.T) {
}
})
t.Run("successful delete with JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
testUser := &database.User{
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword",
}
freshMockRepo := testutils.NewMockUserRepository()
_ = freshMockRepo.Create(testUser)
err := userDelete(cfg, freshMockRepo, []string{"1"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("successful delete with posts", func(t *testing.T) {
testUser2 := &database.User{
Username: "testuser2",
@@ -459,6 +510,17 @@ func TestUserList(t *testing.T) {
}
})
t.Run("list all users with JSON output", func(t *testing.T) {
SetJSONOutput(true)
defer SetJSONOutput(false)
err := userList(mockRepo, []string{})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("list with limit", func(t *testing.T) {
err := userList(mockRepo, []string{"--limit", "1"})

View File

@@ -172,11 +172,12 @@ func FuzzRunCommandHandler(f *testing.F) {
err := handleRunCommand(cfg, args)
if len(args) > 0 && args[0] == "--help" {
switch {
case len(args) > 0 && args[0] == "--help":
if err != nil {
t.Logf("Help flag should not error, got: %v", err)
}
} else if len(args) > 0 {
case len(args) > 0:
if err == nil {
return
}
@@ -190,7 +191,7 @@ func FuzzRunCommandHandler(f *testing.F) {
if !strings.Contains(errMsg, "unexpected arguments") {
t.Logf("Got error (may be acceptable for server setup): %v", err)
}
} else {
default:
if err != nil && strings.Contains(err.Error(), "unexpected arguments") {
t.Fatalf("Empty args should not trigger 'unexpected arguments' error: %v", err)
}

View File

@@ -67,6 +67,7 @@ func run(args []string) error {
rootFS.SetOutput(os.Stderr)
rootFS.Usage = printRootUsage
showHelp := rootFS.Bool("help", false, "show this help message")
jsonOutput := rootFS.Bool("json", cfg.CLI.JSONOutputDefault, "output results in JSON format")
if err := rootFS.Parse(args); err != nil {
if errors.Is(err, flag.ErrHelp) {
@@ -80,6 +81,8 @@ func run(args []string) error {
return nil
}
commands.SetJSONOutput(*jsonOutput)
remaining := rootFS.Args()
if len(remaining) == 0 {
printRootUsage()

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"crypto/tls"
"errors"
"flag"
@@ -95,16 +96,21 @@ func TestServerConfigurationFromConfig(t *testing.T) {
testServer := httptest.NewServer(srv.Handler)
defer testServer.Close()
resp, err := http.Get(testServer.URL + "/health")
request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
response, err := http.DefaultClient.Do(request)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", response.StatusCode)
}
}
@@ -159,6 +165,7 @@ func TestTLSWiringFromConfig(t *testing.T) {
srv := &http.Server{
Addr: expectedAddr,
Handler: router,
ReadHeaderTimeout: 5 * time.Second,
}
if srv.Addr != expectedAddr {
@@ -201,24 +208,27 @@ func TestTLSWiringFromConfig(t *testing.T) {
},
}
resp, err := client.Get(testServer.URL + "/health")
request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
response, err := client.Do(request)
if err != nil {
t.Fatalf("Failed to make TLS request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 over TLS, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 over TLS, got %d", response.StatusCode)
}
if resp.TLS == nil {
if response.TLS == nil {
t.Error("Expected TLS connection info to be present in response")
} else {
if resp.TLS.Version < tls.VersionTLS12 {
t.Errorf("Expected TLS version 1.2 or higher, got %x", resp.TLS.Version)
}
} else if response.TLS.Version < tls.VersionTLS12 {
t.Errorf("Expected TLS version 1.2 or higher, got %x", response.TLS.Version)
}
}
}
@@ -358,28 +368,38 @@ func TestServerInitializationFlow(t *testing.T) {
testServer := httptest.NewServer(srv.Handler)
defer testServer.Close()
resp, err := http.Get(testServer.URL + "/health")
request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
response, err := http.DefaultClient.Do(request)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", response.StatusCode)
}
resp, err = http.Get(testServer.URL + "/api")
request, err = http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/api", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
response, err = http.DefaultClient.Do(request)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
_ = response.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for API endpoint, got %d", resp.StatusCode)
if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for API endpoint, got %d", response.StatusCode)
}
}

View File

@@ -1,6 +1,6 @@
services:
db:
image: postgres:17-alpine
image: postgres:18-alpine
restart: unless-stopped
env_file:
- ../.env

View File

@@ -30,7 +30,7 @@ services:
- goyco
db:
image: postgres:17-alpine
image: postgres:18-alpine
restart: always
env_file:
- ../.env

View File

@@ -111,7 +111,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ConfirmAccountDeletionRequest"
"$ref": "#/definitions/dto.ConfirmAccountDeletionRequest"
}
}
],
@@ -212,7 +212,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateEmailRequest"
"$ref": "#/definitions/dto.UpdateEmailRequest"
}
}
],
@@ -276,7 +276,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ForgotPasswordRequest"
"$ref": "#/definitions/dto.ForgotPasswordRequest"
}
}
],
@@ -316,7 +316,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.LoginRequest"
"$ref": "#/definitions/dto.LoginRequest"
}
}
],
@@ -453,7 +453,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePasswordRequest"
"$ref": "#/definitions/dto.UpdatePasswordRequest"
}
}
],
@@ -505,7 +505,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RefreshTokenRequest"
"$ref": "#/definitions/dto.RefreshTokenRequest"
}
}
],
@@ -563,7 +563,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -615,7 +615,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResendVerificationRequest"
"$ref": "#/definitions/dto.ResendVerificationRequest"
}
}
],
@@ -685,7 +685,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResetPasswordRequest"
"$ref": "#/definitions/dto.ResetPasswordRequest"
}
}
],
@@ -736,7 +736,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RevokeTokenRequest"
"$ref": "#/definitions/dto.RevokeTokenRequest"
}
}
],
@@ -833,7 +833,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateUsernameRequest"
"$ref": "#/definitions/dto.UpdateUsernameRequest"
}
}
],
@@ -945,7 +945,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.CreatePostRequest"
"$ref": "#/definitions/dto.CreatePostRequest"
}
}
],
@@ -1176,7 +1176,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePostRequest"
"$ref": "#/definitions/dto.UpdatePostRequest"
}
}
],
@@ -1370,7 +1370,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.VoteRequest"
"$ref": "#/definitions/dto.CastVoteRequest"
}
}
],
@@ -1601,7 +1601,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -1817,6 +1817,223 @@ const docTemplate = `{
}
},
"definitions": {
"dto.CastVoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"dto.ConfirmAccountDeletionRequest": {
"type": "object",
"required": [
"token"
],
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"dto.CreatePostRequest": {
"type": "object",
"required": [
"url"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
},
"url": {
"type": "string",
"maxLength": 2048
}
}
},
"dto.ForgotPasswordRequest": {
"type": "object",
"required": [
"username_or_email"
],
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"dto.LoginRequest": {
"type": "object",
"required": [
"password",
"username"
],
"properties": {
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.RegisterRequest": {
"type": "object",
"required": [
"email",
"password",
"username"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
},
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.ResendVerificationRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.ResetPasswordRequest": {
"type": "object",
"required": [
"new_password",
"token"
],
"properties": {
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"token": {
"type": "string"
}
}
},
"dto.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.UpdateEmailRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.UpdatePasswordRequest": {
"type": "object",
"required": [
"current_password",
"new_password"
],
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
}
}
},
"dto.UpdatePostRequest": {
"type": "object",
"required": [
"title"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
}
}
},
"dto.UpdateUsernameRequest": {
"type": "object",
"required": [
"username"
],
"properties": {
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"handlers.APIInfo": {
"type": "object",
"properties": {
@@ -1919,50 +2136,6 @@ const docTemplate = `{
}
}
},
"handlers.ConfirmAccountDeletionRequest": {
"type": "object",
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"handlers.CreatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
},
"url": {
"type": "string"
}
}
},
"handlers.ForgotPasswordRequest": {
"type": "object",
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"handlers.LoginRequest": {
"type": "object",
"properties": {
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.PostResponse": {
"type": "object",
"properties": {
@@ -1978,101 +2151,6 @@ const docTemplate = `{
}
}
},
"handlers.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.RegisterRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
},
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.ResendVerificationRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.ResetPasswordRequest": {
"type": "object",
"properties": {
"new_password": {
"type": "string"
},
"token": {
"type": "string"
}
}
},
"handlers.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.UpdateEmailRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.UpdatePasswordRequest": {
"type": "object",
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string"
}
}
},
"handlers.UpdatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
}
}
},
"handlers.UpdateUsernameRequest": {
"type": "object",
"properties": {
"username": {
"type": "string"
}
}
},
"handlers.UserResponse": {
"type": "object",
"properties": {
@@ -2088,21 +2166,6 @@ const docTemplate = `{
}
}
},
"handlers.VoteRequest": {
"description": "Vote request with type field. All votes are handled the same way.",
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
],
"example": "up"
}
}
},
"handlers.VoteResponse": {
"type": "object",
"properties": {

View File

@@ -108,7 +108,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ConfirmAccountDeletionRequest"
"$ref": "#/definitions/dto.ConfirmAccountDeletionRequest"
}
}
],
@@ -209,7 +209,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateEmailRequest"
"$ref": "#/definitions/dto.UpdateEmailRequest"
}
}
],
@@ -273,7 +273,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ForgotPasswordRequest"
"$ref": "#/definitions/dto.ForgotPasswordRequest"
}
}
],
@@ -313,7 +313,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.LoginRequest"
"$ref": "#/definitions/dto.LoginRequest"
}
}
],
@@ -450,7 +450,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePasswordRequest"
"$ref": "#/definitions/dto.UpdatePasswordRequest"
}
}
],
@@ -502,7 +502,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RefreshTokenRequest"
"$ref": "#/definitions/dto.RefreshTokenRequest"
}
}
],
@@ -560,7 +560,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -612,7 +612,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResendVerificationRequest"
"$ref": "#/definitions/dto.ResendVerificationRequest"
}
}
],
@@ -682,7 +682,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.ResetPasswordRequest"
"$ref": "#/definitions/dto.ResetPasswordRequest"
}
}
],
@@ -733,7 +733,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RevokeTokenRequest"
"$ref": "#/definitions/dto.RevokeTokenRequest"
}
}
],
@@ -830,7 +830,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdateUsernameRequest"
"$ref": "#/definitions/dto.UpdateUsernameRequest"
}
}
],
@@ -942,7 +942,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.CreatePostRequest"
"$ref": "#/definitions/dto.CreatePostRequest"
}
}
],
@@ -1173,7 +1173,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.UpdatePostRequest"
"$ref": "#/definitions/dto.UpdatePostRequest"
}
}
],
@@ -1367,7 +1367,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.VoteRequest"
"$ref": "#/definitions/dto.CastVoteRequest"
}
}
],
@@ -1598,7 +1598,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.RegisterRequest"
"$ref": "#/definitions/dto.RegisterRequest"
}
}
],
@@ -1814,6 +1814,223 @@
}
},
"definitions": {
"dto.CastVoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"dto.ConfirmAccountDeletionRequest": {
"type": "object",
"required": [
"token"
],
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"dto.CreatePostRequest": {
"type": "object",
"required": [
"url"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
},
"url": {
"type": "string",
"maxLength": 2048
}
}
},
"dto.ForgotPasswordRequest": {
"type": "object",
"required": [
"username_or_email"
],
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"dto.LoginRequest": {
"type": "object",
"required": [
"password",
"username"
],
"properties": {
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.RegisterRequest": {
"type": "object",
"required": [
"email",
"password",
"username"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
},
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.ResendVerificationRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.ResetPasswordRequest": {
"type": "object",
"required": [
"new_password",
"token"
],
"properties": {
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"token": {
"type": "string"
}
}
},
"dto.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.UpdateEmailRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.UpdatePasswordRequest": {
"type": "object",
"required": [
"current_password",
"new_password"
],
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
}
}
},
"dto.UpdatePostRequest": {
"type": "object",
"required": [
"title"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
}
}
},
"dto.UpdateUsernameRequest": {
"type": "object",
"required": [
"username"
],
"properties": {
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"handlers.APIInfo": {
"type": "object",
"properties": {
@@ -1916,50 +2133,6 @@
}
}
},
"handlers.ConfirmAccountDeletionRequest": {
"type": "object",
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"handlers.CreatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
},
"url": {
"type": "string"
}
}
},
"handlers.ForgotPasswordRequest": {
"type": "object",
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"handlers.LoginRequest": {
"type": "object",
"properties": {
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.PostResponse": {
"type": "object",
"properties": {
@@ -1975,101 +2148,6 @@
}
}
},
"handlers.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.RegisterRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
},
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.ResendVerificationRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.ResetPasswordRequest": {
"type": "object",
"properties": {
"new_password": {
"type": "string"
},
"token": {
"type": "string"
}
}
},
"handlers.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.UpdateEmailRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.UpdatePasswordRequest": {
"type": "object",
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string"
}
}
},
"handlers.UpdatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
}
}
},
"handlers.UpdateUsernameRequest": {
"type": "object",
"properties": {
"username": {
"type": "string"
}
}
},
"handlers.UserResponse": {
"type": "object",
"properties": {
@@ -2085,21 +2163,6 @@
}
}
},
"handlers.VoteRequest": {
"description": "Vote request with type field. All votes are handled the same way.",
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
],
"example": "up"
}
}
},
"handlers.VoteResponse": {
"type": "object",
"properties": {

View File

@@ -1,5 +1,156 @@
basePath: /api
definitions:
dto.CastVoteRequest:
properties:
type:
enum:
- up
- down
- none
type: string
required:
- type
type: object
dto.ConfirmAccountDeletionRequest:
properties:
delete_posts:
type: boolean
token:
type: string
required:
- token
type: object
dto.CreatePostRequest:
properties:
content:
maxLength: 10000
type: string
title:
maxLength: 200
minLength: 3
type: string
url:
maxLength: 2048
type: string
required:
- url
type: object
dto.ForgotPasswordRequest:
properties:
username_or_email:
type: string
required:
- username_or_email
type: object
dto.LoginRequest:
properties:
password:
maxLength: 128
minLength: 8
type: string
username:
maxLength: 50
minLength: 3
type: string
required:
- password
- username
type: object
dto.RefreshTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
dto.RegisterRequest:
properties:
email:
maxLength: 254
type: string
password:
maxLength: 128
minLength: 8
type: string
username:
maxLength: 50
minLength: 3
type: string
required:
- email
- password
- username
type: object
dto.ResendVerificationRequest:
properties:
email:
maxLength: 254
type: string
required:
- email
type: object
dto.ResetPasswordRequest:
properties:
new_password:
maxLength: 128
minLength: 8
type: string
token:
type: string
required:
- new_password
- token
type: object
dto.RevokeTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
dto.UpdateEmailRequest:
properties:
email:
maxLength: 254
type: string
required:
- email
type: object
dto.UpdatePasswordRequest:
properties:
current_password:
type: string
new_password:
maxLength: 128
minLength: 8
type: string
required:
- current_password
- new_password
type: object
dto.UpdatePostRequest:
properties:
content:
maxLength: 10000
type: string
title:
maxLength: 200
minLength: 3
type: string
required:
- title
type: object
dto.UpdateUsernameRequest:
properties:
username:
maxLength: 50
minLength: 3
type: string
required:
- username
type: object
handlers.APIInfo:
properties:
data: {}
@@ -70,34 +221,6 @@ definitions:
success:
type: boolean
type: object
handlers.ConfirmAccountDeletionRequest:
properties:
delete_posts:
type: boolean
token:
type: string
type: object
handlers.CreatePostRequest:
properties:
content:
type: string
title:
type: string
url:
type: string
type: object
handlers.ForgotPasswordRequest:
properties:
username_or_email:
type: string
type: object
handlers.LoginRequest:
properties:
password:
type: string
username:
type: string
type: object
handlers.PostResponse:
properties:
data: {}
@@ -108,67 +231,6 @@ definitions:
success:
type: boolean
type: object
handlers.RefreshTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
handlers.RegisterRequest:
properties:
email:
type: string
password:
type: string
username:
type: string
type: object
handlers.ResendVerificationRequest:
properties:
email:
type: string
type: object
handlers.ResetPasswordRequest:
properties:
new_password:
type: string
token:
type: string
type: object
handlers.RevokeTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
handlers.UpdateEmailRequest:
properties:
email:
type: string
type: object
handlers.UpdatePasswordRequest:
properties:
current_password:
type: string
new_password:
type: string
type: object
handlers.UpdatePostRequest:
properties:
content:
type: string
title:
type: string
type: object
handlers.UpdateUsernameRequest:
properties:
username:
type: string
type: object
handlers.UserResponse:
properties:
data: {}
@@ -179,17 +241,6 @@ definitions:
success:
type: boolean
type: object
handlers.VoteRequest:
description: Vote request with type field. All votes are handled the same way.
properties:
type:
enum:
- up
- down
- none
example: up
type: string
type: object
handlers.VoteResponse:
properties:
data: {}
@@ -268,7 +319,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ConfirmAccountDeletionRequest'
$ref: '#/definitions/dto.ConfirmAccountDeletionRequest'
produces:
- application/json
responses:
@@ -331,7 +382,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdateEmailRequest'
$ref: '#/definitions/dto.UpdateEmailRequest'
produces:
- application/json
responses:
@@ -375,7 +426,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ForgotPasswordRequest'
$ref: '#/definitions/dto.ForgotPasswordRequest'
produces:
- application/json
responses:
@@ -401,7 +452,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.LoginRequest'
$ref: '#/definitions/dto.LoginRequest'
produces:
- application/json
responses:
@@ -485,7 +536,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdatePasswordRequest'
$ref: '#/definitions/dto.UpdatePasswordRequest'
produces:
- application/json
responses:
@@ -523,7 +574,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RefreshTokenRequest'
$ref: '#/definitions/dto.RefreshTokenRequest'
produces:
- application/json
responses:
@@ -561,7 +612,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RegisterRequest'
$ref: '#/definitions/dto.RegisterRequest'
produces:
- application/json
responses:
@@ -595,7 +646,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ResendVerificationRequest'
$ref: '#/definitions/dto.ResendVerificationRequest'
produces:
- application/json
responses:
@@ -641,7 +692,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.ResetPasswordRequest'
$ref: '#/definitions/dto.ResetPasswordRequest'
produces:
- application/json
responses:
@@ -672,7 +723,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RevokeTokenRequest'
$ref: '#/definitions/dto.RevokeTokenRequest'
produces:
- application/json
responses:
@@ -735,7 +786,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdateUsernameRequest'
$ref: '#/definitions/dto.UpdateUsernameRequest'
produces:
- application/json
responses:
@@ -809,7 +860,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.CreatePostRequest'
$ref: '#/definitions/dto.CreatePostRequest'
produces:
- application/json
responses:
@@ -933,7 +984,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.UpdatePostRequest'
$ref: '#/definitions/dto.UpdatePostRequest'
produces:
- application/json
responses:
@@ -1070,7 +1121,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.VoteRequest'
$ref: '#/definitions/dto.CastVoteRequest'
produces:
- application/json
responses:
@@ -1260,7 +1311,7 @@ paths:
name: request
required: true
schema:
$ref: '#/definitions/handlers.RegisterRequest'
$ref: '#/definitions/dto.RegisterRequest'
produces:
- application/json
responses:

View File

@@ -15,6 +15,7 @@ type Config struct {
SMTP SMTPConfig
App AppConfig
RateLimit RateLimitConfig
CLI CLIConfig
LogDir string
PIDDir string
}
@@ -81,6 +82,10 @@ type RateLimitConfig struct {
TrustProxyHeaders bool
}
type CLIConfig struct {
JSONOutputDefault bool
}
func Load() (*Config, error) {
config := &Config{
Database: DatabaseConfig{
@@ -137,6 +142,9 @@ func Load() (*Config, error) {
MetricsLimit: getEnvAsInt("RATE_LIMIT_METRICS", 20),
TrustProxyHeaders: getEnvAsBool("RATE_LIMIT_TRUST_PROXY", false),
},
CLI: CLIConfig{
JSONOutputDefault: getEnvAsBool("CLI_JSON_OUTPUT", false),
},
LogDir: getEnv("LOG_DIR", "/var/log/"),
PIDDir: getEnv("PID_DIR", "/run"),
}

View File

@@ -995,3 +995,63 @@ func TestLoadWithInvalidBcryptCost(t *testing.T) {
})
}
}
func TestCLIConfigJSONOutput(t *testing.T) {
tests := []struct {
name string
envValue string
expectedOutput bool
}{
{
name: "default false when not set",
envValue: "",
expectedOutput: false,
},
{
name: "true when set to true",
envValue: "true",
expectedOutput: true,
},
{
name: "false when set to false",
envValue: "false",
expectedOutput: false,
},
{
name: "true when set to 1",
envValue: "1",
expectedOutput: true,
},
{
name: "false when set to 0",
envValue: "0",
expectedOutput: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("DB_PASSWORD", "password")
t.Setenv("SMTP_HOST", "smtp.example.com")
t.Setenv("SMTP_PORT", "2525")
t.Setenv("SMTP_FROM", "no-reply@example.com")
t.Setenv("ADMIN_EMAIL", "admin@example.com")
t.Setenv("JWT_SECRET", "this-is-a-very-secure-jwt-secret-key-that-is-long-enough")
if tt.envValue != "" {
t.Setenv("CLI_JSON_OUTPUT", tt.envValue)
} else {
os.Unsetenv("CLI_JSON_OUTPUT")
}
cfg, err := Load()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.CLI.JSONOutputDefault != tt.expectedOutput {
t.Fatalf("expected CLI.JSONOutputDefault to be %v, got %v", tt.expectedOutput, cfg.CLI.JSONOutputDefault)
}
})
}
}

View File

@@ -81,7 +81,7 @@ func TestUser_Model(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()
@@ -195,7 +195,7 @@ func TestPost_Model(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()
@@ -305,7 +305,7 @@ func TestVote_Model(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()
@@ -424,7 +424,7 @@ func TestRefreshToken_Model(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()
@@ -483,7 +483,7 @@ func TestAccountDeletionRequest_Model(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()
@@ -554,7 +554,7 @@ func TestModel_SoftDelete(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()

View File

@@ -27,24 +27,47 @@ func (g *GormDBMonitor) Name() string {
}
func (g *GormDBMonitor) Initialize(db *gorm.DB) error {
if err := db.Callback().Create().Before("gorm:create").Register("db_monitor:before_create", g.beforeCreate); err != nil {
return err
}
if err := db.Callback().Create().After("gorm:create").Register("db_monitor:after_create", g.afterCreate); err != nil {
return err
}
db.Callback().Create().Before("gorm:create").Register("db_monitor:before_create", g.beforeCreate)
db.Callback().Create().After("gorm:create").Register("db_monitor:after_create", g.afterCreate)
if err := db.Callback().Query().Before("gorm:query").Register("db_monitor:before_query", g.beforeQuery); err != nil {
return err
}
if err := db.Callback().Query().After("gorm:query").Register("db_monitor:after_query", g.afterQuery); err != nil {
return err
}
db.Callback().Query().Before("gorm:query").Register("db_monitor:before_query", g.beforeQuery)
db.Callback().Query().After("gorm:query").Register("db_monitor:after_query", g.afterQuery)
if err := db.Callback().Update().Before("gorm:update").Register("db_monitor:before_update", g.beforeUpdate); err != nil {
return err
}
if err := db.Callback().Update().After("gorm:update").Register("db_monitor:after_update", g.afterUpdate); err != nil {
return err
}
db.Callback().Update().Before("gorm:update").Register("db_monitor:before_update", g.beforeUpdate)
db.Callback().Update().After("gorm:update").Register("db_monitor:after_update", g.afterUpdate)
if err := db.Callback().Delete().Before("gorm:delete").Register("db_monitor:before_delete", g.beforeDelete); err != nil {
return err
}
if err := db.Callback().Delete().After("gorm:delete").Register("db_monitor:after_delete", g.afterDelete); err != nil {
return err
}
db.Callback().Delete().Before("gorm:delete").Register("db_monitor:before_delete", g.beforeDelete)
db.Callback().Delete().After("gorm:delete").Register("db_monitor:after_delete", g.afterDelete)
if err := db.Callback().Row().Before("gorm:row").Register("db_monitor:before_row", g.beforeRow); err != nil {
return err
}
if err := db.Callback().Row().After("gorm:row").Register("db_monitor:after_row", g.afterRow); err != nil {
return err
}
db.Callback().Row().Before("gorm:row").Register("db_monitor:before_row", g.beforeRow)
db.Callback().Row().After("gorm:row").Register("db_monitor:after_row", g.afterRow)
db.Callback().Raw().Before("gorm:raw").Register("db_monitor:before_raw", g.beforeRaw)
db.Callback().Raw().After("gorm:raw").Register("db_monitor:after_raw", g.afterRaw)
if err := db.Callback().Raw().Before("gorm:raw").Register("db_monitor:before_raw", g.beforeRaw); err != nil {
return err
}
if err := db.Callback().Raw().After("gorm:raw").Register("db_monitor:after_raw", g.afterRaw); err != nil {
return err
}
return nil
}

View File

@@ -15,10 +15,6 @@ func TestNewGormDBMonitor(t *testing.T) {
monitor := middleware.NewInMemoryDBMonitor()
gormMonitor := NewGormDBMonitor(monitor)
if gormMonitor == nil {
t.Fatal("Expected non-nil GormDBMonitor")
}
if gormMonitor.monitor != monitor {
t.Error("Expected monitor to be set correctly")
}
@@ -40,7 +36,7 @@ func TestGormDBMonitor_Initialize(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()
@@ -56,7 +52,7 @@ func TestGormDBMonitor_InitializeWithNilMonitor(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()
@@ -73,7 +69,7 @@ func TestGormDBMonitor_Callbacks(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()
@@ -114,7 +110,7 @@ func TestGormDBMonitor_CallbacksWithNilMonitor(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()
@@ -142,7 +138,7 @@ func TestGormDBMonitor_BuildQueryString(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()
@@ -280,7 +276,7 @@ func TestGormDBMonitor_WithRealDatabase(t *testing.T) {
}
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
_ = sqlDB.Close()
}
}()

View File

@@ -0,0 +1,51 @@
package dto
type LoginRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
Password string `json:"password" validate:"required,min=8,max=128"`
}
type RegisterRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
Email string `json:"email" validate:"required,email,max=254"`
Password string `json:"password" validate:"required,min=8,max=128"`
}
type ResendVerificationRequest struct {
Email string `json:"email" validate:"required,email,max=254"`
}
type ForgotPasswordRequest struct {
UsernameOrEmail string `json:"username_or_email" validate:"required"`
}
type ResetPasswordRequest struct {
Token string `json:"token" validate:"required"`
NewPassword string `json:"new_password" validate:"required,min=8,max=128"`
}
type UpdateEmailRequest struct {
Email string `json:"email" validate:"required,email,max=254"`
}
type UpdateUsernameRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
}
type UpdatePasswordRequest struct {
CurrentPassword string `json:"current_password" validate:"required"`
NewPassword string `json:"new_password" validate:"required,min=8,max=128"`
}
type ConfirmAccountDeletionRequest struct {
Token string `json:"token" validate:"required"`
DeletePosts bool `json:"delete_posts"`
}
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." validate:"required"`
}
type RevokeTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." validate:"required"`
}

View File

@@ -0,0 +1,12 @@
package dto
type CreatePostRequest struct {
Title string `json:"title" validate:"omitempty,min=3,max=200"`
URL string `json:"url" validate:"required,url,max=2048"`
Content string `json:"content" validate:"omitempty,max=10000"`
}
type UpdatePostRequest struct {
Title string `json:"title" validate:"required,min=3,max=200"`
Content string `json:"content" validate:"omitempty,max=10000"`
}

View File

@@ -6,6 +6,10 @@ import (
"goyco/internal/database"
)
type CastVoteRequest struct {
Type string `json:"type" validate:"required,oneof=up down none"`
}
type VoteDTO struct {
ID uint `json:"id"`
UserID *uint `json:"user_id,omitempty"`

View File

@@ -11,11 +11,12 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"goyco/internal/config"
"goyco/internal/repositories"
"goyco/internal/services"
"goyco/internal/testutils"
"github.com/golang-jwt/jwt/v5"
)
func TestE2E_APIRegistration(t *testing.T) {
@@ -377,7 +378,7 @@ func TestE2E_RefreshTokenFlow(t *testing.T) {
for i := range 3 {
var statusCode int
var newAccessToken string
for attempt := 0; attempt < 3; attempt++ {
for attempt := range 3 {
newAccessToken, statusCode = authClient.RefreshAccessToken(t, testutils.GenerateTestIP())
if statusCode != http.StatusTooManyRequests {
break

View File

@@ -112,37 +112,37 @@ func newInMemoryRoundTripper(handler http.Handler) http.RoundTripper {
return &inMemoryRoundTripper{handler: handler}
}
func (rt *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func (rt *inMemoryRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
if rt == nil || rt.handler == nil {
return nil, fmt.Errorf("in-memory round tripper not initialized")
}
var bodyBytes []byte
if req.Body != nil && req.Body != http.NoBody {
defer req.Body.Close()
if request.Body != nil && request.Body != http.NoBody {
defer request.Body.Close()
var err error
bodyBytes, err = io.ReadAll(req.Body)
bodyBytes, err = io.ReadAll(request.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}
}
if len(bodyBytes) > 0 {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
} else {
req.Body = http.NoBody
request.Body = http.NoBody
}
clonedReq := req.Clone(req.Context())
clonedRequest := request.Clone(request.Context())
if len(bodyBytes) > 0 {
clonedReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
clonedRequest.Body = io.NopCloser(bytes.NewReader(bodyBytes))
} else {
clonedReq.Body = http.NoBody
clonedRequest.Body = http.NoBody
}
clonedReq.RequestURI = clonedReq.URL.RequestURI()
clonedRequest.RequestURI = clonedRequest.URL.RequestURI()
recorder := httptest.NewRecorder()
rt.handler.ServeHTTP(recorder, clonedReq)
rt.handler.ServeHTTP(recorder, clonedRequest)
resp := recorder.Result()
return resp, nil
}
@@ -203,10 +203,7 @@ func uniqueUsername(t *testing.T, prefix string) string {
fullPrefix := fmt.Sprintf("%s_%s", filePrefix, prefix)
username := fmt.Sprintf("%s_%s", fullPrefix, uniqueTestID(t))
if len(username) > 50 {
maxIDLength := 50 - len(fullPrefix) - 1
if maxIDLength < 0 {
maxIDLength = 0
}
maxIDLength := max(50-len(fullPrefix)-1, 0)
testID := uniqueTestID(t)
if len(testID) > maxIDLength {
testID = testID[:maxIDLength]
@@ -253,7 +250,7 @@ func tokenHash(token string) string {
func retryOnRateLimit(t *testing.T, maxRetries int, operation func() int) int {
t.Helper()
for attempt := 0; attempt < maxRetries; attempt++ {
for attempt := range maxRetries {
statusCode := operation()
if statusCode != http.StatusTooManyRequests {
return statusCode

View File

@@ -15,22 +15,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
ctx := setupTestContext(t)
t.Run("compression_enabled_with_accept_encoding", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Accept-Encoding", "gzip")
testutils.WithStandardHeaders(req)
request.Header.Set("Accept-Encoding", "gzip")
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
contentEncoding := resp.Header.Get("Content-Encoding")
contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" {
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(response.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
@@ -57,19 +57,19 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
})
t.Run("no_compression_without_accept_encoding", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req)
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
contentEncoding := resp.Header.Get("Content-Encoding")
contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" {
t.Error("Expected no compression without Accept-Encoding header")
}
@@ -85,22 +85,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
gz.Write([]byte(postData))
gz.Close()
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf)
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Encoding", "gzip")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Content-Encoding", "gzip")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
switch resp.StatusCode {
switch response.StatusCode {
case http.StatusBadRequest:
t.Log("Decompression middleware rejected invalid gzip")
case http.StatusCreated, http.StatusOK:
@@ -113,37 +113,37 @@ func TestE2E_CacheMiddleware(t *testing.T) {
ctx := setupTestContext(t)
t.Run("cache_miss_then_hit", func(t *testing.T) {
req1, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req1)
testutils.WithStandardHeaders(firstRequest)
resp1, err := ctx.client.Do(req1)
firstResponse, err := ctx.client.Do(firstRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
resp1.Body.Close()
firstResponse.Body.Close()
cacheStatus1 := resp1.Header.Get("X-Cache")
if cacheStatus1 == "HIT" {
firstCacheStatus := firstResponse.Header.Get("X-Cache")
if firstCacheStatus == "HIT" {
t.Log("First request was cached (unexpected but acceptable)")
}
req2, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
secondRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req2)
testutils.WithStandardHeaders(secondRequest)
resp2, err := ctx.client.Do(req2)
secondResponse, err := ctx.client.Do(secondRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp2.Body.Close()
defer secondResponse.Body.Close()
cacheStatus2 := resp2.Header.Get("X-Cache")
if cacheStatus2 == "HIT" {
secondCacheStatus := secondResponse.Header.Get("X-Cache")
if secondCacheStatus == "HIT" {
t.Log("Second request was served from cache")
}
})
@@ -152,48 +152,48 @@ func TestE2E_CacheMiddleware(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "cacheuser", "StrongPass123!")
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
req1, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req1)
req1.Header.Set("Authorization", "Bearer "+authClient.Token)
testutils.WithStandardHeaders(firstRequest)
firstRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp1, err := ctx.client.Do(req1)
firstResponse, err := ctx.client.Do(firstRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
resp1.Body.Close()
firstResponse.Body.Close()
postData := `{"title":"Cache Invalidation Test","url":"https://example.com/cache","content":"Test"}`
req2, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
secondRequest, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req2.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req2)
req2.Header.Set("Authorization", "Bearer "+authClient.Token)
secondRequest.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(secondRequest)
secondRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp2, err := ctx.client.Do(req2)
secondResponse, err := ctx.client.Do(secondRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
resp2.Body.Close()
secondResponse.Body.Close()
req3, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
thirdRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req3)
req3.Header.Set("Authorization", "Bearer "+authClient.Token)
testutils.WithStandardHeaders(thirdRequest)
thirdRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp3, err := ctx.client.Do(req3)
thirdResponse, err := ctx.client.Do(thirdRequest)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp3.Body.Close()
defer thirdResponse.Body.Close()
cacheStatus := resp3.Header.Get("X-Cache")
cacheStatus := thirdResponse.Header.Get("X-Cache")
if cacheStatus == "HIT" {
t.Log("Cache was invalidated after POST")
}
@@ -204,23 +204,23 @@ func TestE2E_CSRFProtection(t *testing.T) {
ctx := setupTestContext(t)
t.Run("csrf_protection_for_non_api_routes", func(t *testing.T) {
req, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`))
request, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden {
if response.StatusCode == http.StatusForbidden {
t.Log("CSRF protection active for non-API routes")
} else {
t.Logf("CSRF check result: status %d", resp.StatusCode)
t.Logf("CSRF check result: status %d", response.StatusCode)
}
})
@@ -229,39 +229,39 @@ func TestE2E_CSRFProtection(t *testing.T) {
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
postData := `{"title":"CSRF Test","url":"https://example.com/csrf","content":"Test"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden {
if response.StatusCode == http.StatusForbidden {
t.Error("API routes should bypass CSRF protection")
}
})
t.Run("csrf_allows_get_requests", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil)
request, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
testutils.WithStandardHeaders(req)
testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden {
if response.StatusCode == http.StatusForbidden {
t.Error("GET requests should not require CSRF token")
}
})
@@ -276,21 +276,21 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
smallData := strings.Repeat("a", 100)
postData := `{"title":"` + smallData + `","url":"https://example.com","content":"test"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusRequestEntityTooLarge {
if response.StatusCode == http.StatusRequestEntityTooLarge {
t.Error("Small request should not exceed size limit")
}
})
@@ -301,24 +301,24 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
largeData := strings.Repeat("a", 2*1024*1024)
postData := `{"title":"test","url":"https://example.com","content":"` + largeData + `"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req)
req.Header.Set("Authorization", "Bearer "+authClient.Token)
request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(request)
request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req)
response, err := ctx.client.Do(request)
if err != nil {
return
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode == http.StatusRequestEntityTooLarge {
if response.StatusCode == http.StatusRequestEntityTooLarge {
t.Log("Request size limit enforced correctly")
} else {
t.Logf("Request size limit check result: status %d", resp.StatusCode)
t.Logf("Request size limit check result: status %d", response.StatusCode)
}
})
}

View File

@@ -17,7 +17,7 @@ func TestE2E_RateLimitingHeaders(t *testing.T) {
t.Run("rate_limit_headers_present", func(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "ratelimituser", "StrongPass123!")
for i := 0; i < 3; i++ {
for range 3 {
req, err := http.NewRequest("POST", ctx.baseURL+"/api/auth/login", strings.NewReader(`{"username":"`+testUser.Username+`","password":"StrongPass123!"}`))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
@@ -53,7 +53,7 @@ func TestE2E_RateLimitingHeaders(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "ratelimituser2", "StrongPass123!")
testIP := testutils.GenerateTestIP()
for i := 0; i < 4; i++ {
for i := range 4 {
req, err := http.NewRequest("POST", ctx.baseURL+"/api/auth/login", strings.NewReader(`{"username":"`+testUser.Username+`","password":"StrongPass123!"}`))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
@@ -95,7 +95,7 @@ func TestE2E_RateLimitResetBehavior(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "resetuser", "StrongPass123!")
testIP := testutils.GenerateTestIP()
for i := 0; i < 2; i++ {
for range 2 {
req, err := http.NewRequest("POST", ctx.baseURL+"/api/auth/login", strings.NewReader(`{"username":"`+testUser.Username+`","password":"StrongPass123!"}`))
if err != nil {
continue
@@ -178,7 +178,7 @@ func TestE2E_RateLimitDifferentScenarios(t *testing.T) {
successCount1 := 0
successCount2 := 0
for i := 0; i < 5; i++ {
for range 5 {
req1, _ := http.NewRequest("POST", ctx.baseURL+"/api/auth/login", strings.NewReader(`{"username":"`+testUser.Username+`","password":"StrongPass123!"}`))
req1.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req1)
@@ -221,7 +221,7 @@ func TestE2E_RateLimitDifferentScenarios(t *testing.T) {
successCount1 := 0
successCount2 := 0
for i := 0; i < 10; i++ {
for range 10 {
req1, _ := http.NewRequest("GET", ctx.baseURL+"/api/auth/me", nil)
testutils.WithStandardHeaders(req1)
req1.Header.Set("Authorization", "Bearer "+authClient1.Token)

View File

@@ -432,7 +432,7 @@ func TestE2E_TokenReplayAttack(t *testing.T) {
token := authClient.Token
t.Run("same_token_multiple_times", func(t *testing.T) {
for i := 0; i < 5; i++ {
for i := range 5 {
statusCode := ctx.makeRequestWithToken(t, token)
if statusCode != http.StatusOK {
t.Errorf("Expected token to work multiple times (replay %d), got status %d", i+1, statusCode)

View File

@@ -64,62 +64,6 @@ type AuthUserSummary struct {
Locked bool `json:"locked" example:"false"`
}
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type RegisterRequest struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
type CreatePostRequest struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
type ResendVerificationRequest struct {
Email string `json:"email"`
}
type ForgotPasswordRequest struct {
UsernameOrEmail string `json:"username_or_email"`
}
type ResetPasswordRequest struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
type UpdateEmailRequest struct {
Email string `json:"email"`
}
type UpdateUsernameRequest struct {
Username string `json:"username"`
}
type UpdatePasswordRequest struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
type ConfirmAccountDeletionRequest struct {
Token string `json:"token"`
DeletePosts bool `json:"delete_posts"`
}
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
}
type RevokeTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
}
func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.UserRepository) *AuthHandler {
return &AuthHandler{
authService: authService,
@@ -132,7 +76,7 @@ func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.User
// @Tags auth
// @Accept json
// @Produce json
// @Param request body LoginRequest true "Login credentials"
// @Param request body dto.LoginRequest true "Login credentials"
// @Success 200 {object} AuthTokensResponse "Authentication successful"
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
// @Failure 401 {object} AuthResponse "Invalid credentials"
@@ -140,23 +84,15 @@ func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.User
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/login [post]
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.LoginRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
username := security.SanitizeUsername(req.Username)
password := strings.TrimSpace(req.Password)
if username == "" || password == "" {
SendErrorResponse(w, "Username and password are required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(password); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
@@ -175,20 +111,16 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body RegisterRequest true "Registration data"
// @Param request body dto.RegisterRequest true "Registration data"
// @Success 201 {object} AuthResponse "Registration successful"
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
// @Failure 409 {object} AuthResponse "Username or email already exists"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/register [post]
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RegisterRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -196,11 +128,6 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
email := strings.TrimSpace(req.Email)
password := strings.TrimSpace(req.Password)
if username == "" || email == "" || password == "" {
SendErrorResponse(w, "Username, email, and password are required", http.StatusBadRequest)
return
}
username = security.SanitizeUsername(username)
if err := validation.ValidateUsername(username); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
@@ -280,7 +207,7 @@ func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ResendVerificationRequest true "Email address"
// @Param request body dto.ResendVerificationRequest true "Email address"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 404 {object} AuthResponse
@@ -290,15 +217,14 @@ func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse
// @Router /api/auth/resend-verification [post]
func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ResendVerificationRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
email := strings.TrimSpace(req.Email)
if email == "" {
SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
return
@@ -359,20 +285,19 @@ func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ForgotPasswordRequest true "Username or email"
// @Param request body dto.ForgotPasswordRequest true "Username or email"
// @Success 200 {object} AuthResponse "Password reset email sent if account exists"
// @Failure 400 {object} AuthResponse "Invalid request data"
// @Router /api/auth/forgot-password [post]
func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
var req struct {
UsernameOrEmail string `json:"username_or_email"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ForgotPasswordRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail)
if usernameOrEmail == "" {
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
return
@@ -389,18 +314,15 @@ func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Reques
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ResetPasswordRequest true "Password reset data"
// @Param request body dto.ResetPasswordRequest true "Password reset data"
// @Success 200 {object} AuthResponse "Password reset successfully"
// @Failure 400 {object} AuthResponse "Invalid or expired token, or validation failed"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/reset-password [post]
func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
var req struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ResetPasswordRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -408,17 +330,12 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
newPassword := strings.TrimSpace(req.NewPassword)
if token == "" {
SendErrorResponse(w, "Reset token is required", http.StatusBadRequest)
SendErrorResponse(w, "Token is required", http.StatusBadRequest)
return
}
if newPassword == "" {
SendErrorResponse(w, "New password is required", http.StatusBadRequest)
return
}
if len(newPassword) < 8 {
SendErrorResponse(w, "Password must be at least 8 characters long", http.StatusBadRequest)
if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
@@ -443,7 +360,7 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdateEmailRequest true "New email address"
// @Param request body dto.UpdateEmailRequest true "New email address"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
@@ -457,11 +374,9 @@ func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
Email string `json:"email"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdateEmailRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -498,7 +413,7 @@ func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdateUsernameRequest true "New username"
// @Param request body dto.UpdateUsernameRequest true "New username"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
@@ -511,11 +426,9 @@ func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
Username string `json:"username"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdateUsernameRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -548,7 +461,7 @@ func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdatePasswordRequest true "Password update data"
// @Param request body dto.UpdatePasswordRequest true "Password update data"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
@@ -560,12 +473,9 @@ func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdatePasswordRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -633,23 +543,21 @@ func (h *AuthHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ConfirmAccountDeletionRequest true "Account deletion data"
// @Param request body dto.ConfirmAccountDeletionRequest true "Account deletion data"
// @Success 200 {object} AuthResponse "Account deleted successfully"
// @Failure 400 {object} AuthResponse "Invalid or expired token"
// @Failure 503 {object} AuthResponse "Email delivery unavailable"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/account/confirm [post]
func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
var req struct {
Token string `json:"token"`
DeletePosts bool `json:"delete_posts"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.ConfirmAccountDeletionRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
token := strings.TrimSpace(req.Token)
if token == "" {
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
return
@@ -694,7 +602,7 @@ func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
// @Tags auth
// @Accept json
// @Produce json
// @Param request body RefreshTokenRequest true "Refresh token data"
// @Param request body dto.RefreshTokenRequest true "Refresh token data"
// @Success 200 {object} AuthTokensResponse "Token refreshed successfully"
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
// @Failure 401 {object} AuthResponse "Invalid or expired refresh token"
@@ -702,13 +610,13 @@ func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/refresh [post]
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
var req RefreshTokenRequest
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RefreshTokenRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.RefreshToken) == "" {
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
@@ -727,20 +635,20 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body RevokeTokenRequest true "Token revocation data"
// @Param request body dto.RevokeTokenRequest true "Token revocation data"
// @Success 200 {object} AuthResponse "Token revoked successfully"
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
// @Failure 401 {object} AuthResponse "Invalid or expired access token"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/revoke [post]
func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
var req RevokeTokenRequest
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RevokeTokenRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.RefreshToken) == "" {
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
@@ -782,28 +690,28 @@ func (h *AuthHandler) RevokeAllTokens(w http.ResponseWriter, r *http.Request) {
func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
if config.GeneralRateLimit != nil {
rateLimited := config.GeneralRateLimit(r)
rateLimited.Post("/auth/refresh", h.RefreshToken)
rateLimited.Post("/auth/refresh", WithValidation[dto.RefreshTokenRequest](config.ValidationMiddleware, h.RefreshToken))
rateLimited.Get("/auth/confirm", h.ConfirmEmail)
rateLimited.Post("/auth/resend-verification", h.ResendVerificationEmail)
rateLimited.Post("/auth/resend-verification", WithValidation[dto.ResendVerificationRequest](config.ValidationMiddleware, h.ResendVerificationEmail))
} else {
r.Post("/auth/refresh", h.RefreshToken)
r.Post("/auth/refresh", WithValidation[dto.RefreshTokenRequest](config.ValidationMiddleware, h.RefreshToken))
r.Get("/auth/confirm", h.ConfirmEmail)
r.Post("/auth/resend-verification", h.ResendVerificationEmail)
r.Post("/auth/resend-verification", WithValidation[dto.ResendVerificationRequest](config.ValidationMiddleware, h.ResendVerificationEmail))
}
if config.AuthRateLimit != nil {
rateLimited := config.AuthRateLimit(r)
rateLimited.Post("/auth/register", h.Register)
rateLimited.Post("/auth/login", h.Login)
rateLimited.Post("/auth/forgot-password", h.RequestPasswordReset)
rateLimited.Post("/auth/reset-password", h.ResetPassword)
rateLimited.Post("/auth/account/confirm", h.ConfirmAccountDeletion)
rateLimited.Post("/auth/register", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.Register))
rateLimited.Post("/auth/login", WithValidation[dto.LoginRequest](config.ValidationMiddleware, h.Login))
rateLimited.Post("/auth/forgot-password", WithValidation[dto.ForgotPasswordRequest](config.ValidationMiddleware, h.RequestPasswordReset))
rateLimited.Post("/auth/reset-password", WithValidation[dto.ResetPasswordRequest](config.ValidationMiddleware, h.ResetPassword))
rateLimited.Post("/auth/account/confirm", WithValidation[dto.ConfirmAccountDeletionRequest](config.ValidationMiddleware, h.ConfirmAccountDeletion))
} else {
r.Post("/auth/register", h.Register)
r.Post("/auth/login", h.Login)
r.Post("/auth/forgot-password", h.RequestPasswordReset)
r.Post("/auth/reset-password", h.ResetPassword)
r.Post("/auth/account/confirm", h.ConfirmAccountDeletion)
r.Post("/auth/register", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.Register))
r.Post("/auth/login", WithValidation[dto.LoginRequest](config.ValidationMiddleware, h.Login))
r.Post("/auth/forgot-password", WithValidation[dto.ForgotPasswordRequest](config.ValidationMiddleware, h.RequestPasswordReset))
r.Post("/auth/reset-password", WithValidation[dto.ResetPasswordRequest](config.ValidationMiddleware, h.ResetPassword))
r.Post("/auth/account/confirm", WithValidation[dto.ConfirmAccountDeletionRequest](config.ValidationMiddleware, h.ConfirmAccountDeletion))
}
protected := r
@@ -816,10 +724,10 @@ func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
protected.Get("/auth/me", h.Me)
protected.Post("/auth/logout", h.Logout)
protected.Post("/auth/revoke", h.RevokeToken)
protected.Post("/auth/revoke", WithValidation[dto.RevokeTokenRequest](config.ValidationMiddleware, h.RevokeToken))
protected.Post("/auth/revoke-all", h.RevokeAllTokens)
protected.Put("/auth/email", h.UpdateEmail)
protected.Put("/auth/username", h.UpdateUsername)
protected.Put("/auth/password", h.UpdatePassword)
protected.Put("/auth/email", WithValidation[dto.UpdateEmailRequest](config.ValidationMiddleware, h.UpdateEmail))
protected.Put("/auth/username", WithValidation[dto.UpdateUsernameRequest](config.ValidationMiddleware, h.UpdateUsername))
protected.Put("/auth/password", WithValidation[dto.UpdatePasswordRequest](config.ValidationMiddleware, h.UpdatePassword))
protected.Delete("/auth/account", h.DeleteAccount)
}

View File

@@ -252,8 +252,8 @@ func TestAuthHandlerLoginSuccess(t *testing.T) {
}
handler := newAuthHandler(repo)
body := bytes.NewBufferString(`{"username":"user","password":"Password123!"}`)
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", body)
bodyStr := `{"username":"user","password":"Password123!"}`
request := createLoginRequest(bodyStr)
recorder := httptest.NewRecorder()
handler.Login(recorder, request)
@@ -274,17 +274,17 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler := newAuthHandler(&testutils.UserRepositoryStub{})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString("invalid"))
request := createLoginRequest("invalid")
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":" ","password":""}`))
request = createLoginRequest(`{"username":" ","password":""}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"WrongPass123!"}`))
request = createLoginRequest(`{"username":"user","password":"WrongPass123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
@@ -294,7 +294,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
}}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`))
request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden)
@@ -304,7 +304,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`))
request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
@@ -330,8 +330,7 @@ func TestAuthHandlerRegisterSuccess(t *testing.T) {
return nil
}})
body := bytes.NewBufferString(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`)
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", body)
request := createRegisterRequest(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.Register(recorder, request)
@@ -354,12 +353,12 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
handler := newAuthHandler(repo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString("invalid"))
request := createRegisterRequest("invalid")
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"","email":"","password":""}`))
request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -368,7 +367,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
}}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"new","email":"taken@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"new","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
@@ -382,7 +381,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"another","email":"taken@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"another","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
}
@@ -477,7 +476,7 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user@example.com"}`))
request := createForgotPasswordRequest(`{"username_or_email":"user@example.com"}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
@@ -495,19 +494,19 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}})
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user"}`))
request = createForgotPasswordRequest(`{"username_or_email":"user"}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":""}`))
request = createForgotPasswordRequest(`{"username_or_email":""}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`invalid json`))
request = createForgotPasswordRequest(`invalid json`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -518,25 +517,25 @@ func TestAuthHandlerResetPassword(t *testing.T) {
handler := newAuthHandler(repo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"new_password":"NewPassword123!"}`))
request := createResetPasswordRequest(`{"new_password":"NewPassword123!"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token"}`))
request = createResetPasswordRequest(`{"token":"valid_token"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token","new_password":"short"}`))
request = createResetPasswordRequest(`{"token":"valid_token","new_password":"short"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`invalid json`))
request = createResetPasswordRequest(`invalid json`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -602,7 +601,7 @@ func TestAuthHandlerResetPasswordServiceOutcomes(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"abc","new_password":"Password123!"}`))
request := createResetPasswordRequest(`{"token":"abc","new_password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.ResetPassword(recorder, request)
@@ -664,7 +663,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "empty email",
@@ -702,7 +701,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/email", bytes.NewBufferString(tt.requestBody))
request := createUpdateEmailRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -789,7 +788,7 @@ func TestAuthHandlerUpdateUsername(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/username", bytes.NewBufferString(tt.requestBody))
request := createUpdateUsernameRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -886,7 +885,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
tt.mockSetup(repo)
handler := newAuthHandler(repo)
request := httptest.NewRequest(http.MethodPut, "/api/auth/password", bytes.NewBufferString(tt.requestBody))
request := createUpdatePasswordRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -984,8 +983,7 @@ func TestAuthHandlerDeleteAccount(t *testing.T) {
func TestAuthHandlerResendVerificationEmail(t *testing.T) {
makeRequest := func(body string, setup func(*mockAuthService)) (*httptest.ResponseRecorder, AuthResponse) {
request := httptest.NewRequest(http.MethodPost, "/api/auth/resend-verification", bytes.NewBufferString(body))
request = request.WithContext(context.Background())
request := createResendVerificationRequest(body)
repo := &testutils.UserRepositoryStub{}
mockService := &mockAuthService{}
@@ -1014,7 +1012,7 @@ func TestAuthHandlerResendVerificationEmail(t *testing.T) {
name: "invalid json",
body: "not-json",
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "missing email",
@@ -1139,7 +1137,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
name: "invalid json",
body: "not-json",
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "missing token",
@@ -1209,7 +1207,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/account/confirm", bytes.NewBufferString(tt.body))
request := createConfirmAccountDeletionRequest(tt.body)
recorder := httptest.NewRecorder()
handler.ConfirmAccountDeletion(recorder, request)
@@ -1338,9 +1336,7 @@ func TestAuthHandler_ConcurrentAccess(t *testing.T) {
for i := 0; i < concurrency; i++ {
go func() {
body := bytes.NewBufferString(`{"username":"testuser","password":"Password123!"}`)
req := httptest.NewRequest("POST", "/api/auth/login", body)
req.Header.Set("Content-Type", "application/json")
req := createLoginRequest(`{"username":"testuser","password":"Password123!"}`)
w := httptest.NewRecorder()
handler.Login(w, req)
@@ -1370,8 +1366,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}, nil
}
body := bytes.NewBufferString(`{"refresh_token":"valid_refresh_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"valid_refresh_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1381,8 +1376,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
})
t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`invalid json`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1392,8 +1386,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
})
t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":""}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1407,8 +1400,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenExpired
}
body := bytes.NewBufferString(`{"refresh_token":"expired_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"expired_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1422,8 +1414,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenInvalid
}
body := bytes.NewBufferString(`{"refresh_token":"invalid_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"invalid_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1437,8 +1428,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrAccountLocked
}
body := bytes.NewBufferString(`{"refresh_token":"locked_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"locked_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1452,8 +1442,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, fmt.Errorf("internal error")
}
body := bytes.NewBufferString(`{"refresh_token":"error_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"error_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1473,8 +1462,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return nil
}
body := bytes.NewBufferString(`{"refresh_token":"token_to_revoke"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":"token_to_revoke"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1484,8 +1472,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
})
t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`invalid json`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1495,8 +1482,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
})
t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":""}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1510,8 +1496,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return fmt.Errorf("revoke failed")
}
body := bytes.NewBufferString(`{"refresh_token":"token"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":"token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/url"
"reflect"
"strconv"
"strings"
"time"
@@ -290,3 +291,24 @@ func HandleServiceError(w http.ResponseWriter, err error, defaultMsg string, def
SendErrorResponse(w, defaultMsg, defaultCode)
return false
}
func GetValidatedDTO[T any](r *http.Request) (*T, bool) {
dtoVal := middleware.GetValidatedDTOFromContext(r.Context())
if dtoVal == nil {
return nil, false
}
dto, ok := dtoVal.(*T)
return dto, ok
}
func WithValidation[T any](validationMiddleware func(http.Handler) http.Handler, handler http.HandlerFunc) http.HandlerFunc {
if validationMiddleware == nil {
return handler
}
var zero T
dtoType := reflect.TypeOf(zero)
return func(w http.ResponseWriter, r *http.Request) {
ctx := middleware.SetDTOTypeInContext(r.Context(), dtoType)
validationMiddleware(handler).ServeHTTP(w, r.WithContext(ctx))
}
}

View File

@@ -1,6 +1,7 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"errors"
@@ -11,6 +12,7 @@ import (
"testing"
"goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/middleware"
"goyco/internal/services"
"goyco/internal/testutils"
@@ -721,6 +723,74 @@ func TestDecodeJSONRequest(t *testing.T) {
}
}
func createRequestWithDTO[T any](method, url string, body []byte) *http.Request {
r := httptest.NewRequest(method, url, bytes.NewReader(body))
var dto T
if len(body) > 0 {
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&dto); err != nil {
return r
}
}
ctx := middleware.SetValidatedDTOInContext(r.Context(), &dto)
return r.WithContext(ctx)
}
func createLoginRequest(body string) *http.Request {
return createRequestWithDTO[dto.LoginRequest](http.MethodPost, "/api/auth/login", []byte(body))
}
func createRegisterRequest(body string) *http.Request {
return createRequestWithDTO[dto.RegisterRequest](http.MethodPost, "/api/auth/register", []byte(body))
}
func createResendVerificationRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResendVerificationRequest](http.MethodPost, "/api/auth/resend-verification", []byte(body))
}
func createForgotPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ForgotPasswordRequest](http.MethodPost, "/api/auth/forgot-password", []byte(body))
}
func createResetPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResetPasswordRequest](http.MethodPost, "/api/auth/reset-password", []byte(body))
}
func createUpdateEmailRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateEmailRequest](http.MethodPut, "/api/auth/email", []byte(body))
}
func createUpdateUsernameRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateUsernameRequest](http.MethodPut, "/api/auth/username", []byte(body))
}
func createUpdatePasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePasswordRequest](http.MethodPut, "/api/auth/password", []byte(body))
}
func createConfirmAccountDeletionRequest(body string) *http.Request {
return createRequestWithDTO[dto.ConfirmAccountDeletionRequest](http.MethodPost, "/api/auth/account/confirm", []byte(body))
}
func createRefreshTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RefreshTokenRequest](http.MethodPost, "/api/auth/refresh", []byte(body))
}
func createRevokeTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RevokeTokenRequest](http.MethodPost, "/api/auth/revoke", []byte(body))
}
func createCreatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.CreatePostRequest](http.MethodPost, "/api/posts", []byte(body))
}
func createUpdatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePostRequest](http.MethodPut, "/api/posts/1", []byte(body))
}
func createVoteRequest(body string) *http.Request {
return createRequestWithDTO[dto.CastVoteRequest](http.MethodPost, "/api/posts/1/vote", []byte(body))
}
func TestParsePagination(t *testing.T) {
tests := []struct {
name string

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"html/template"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
@@ -877,7 +878,8 @@ func (h *PageHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -897,7 +899,8 @@ func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -960,13 +963,15 @@ func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
}
h.clearAuthCookie(w, r)
http.Redirect(w, r, "/login?flash=Email updated. Check your inbox to confirm the new address. You will need to sign in again after verification.", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Email updated. Check your inbox to confirm the new address. You will need to sign in again after verification.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -1022,13 +1027,15 @@ func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, "/settings?flash=Username updated successfully.", http.StatusSeeOther)
redirectURL := "/settings?flash=" + url.QueryEscape("Username updated successfully.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -1140,13 +1147,15 @@ func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, "/settings?flash=Password updated successfully.", http.StatusSeeOther)
redirectURL := "/settings?flash=" + url.QueryEscape("Password updated successfully.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
@@ -1204,7 +1213,8 @@ func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, "/settings?flash=Check your inbox for a confirmation link to finish deleting your account.", http.StatusSeeOther)
redirectURL := "/settings?flash=" + url.QueryEscape("Check your inbox for a confirmation link to finish deleting your account.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
func (h *PageHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
@@ -1328,7 +1338,8 @@ func (h *PageHandler) clearAuthCookie(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) Vote(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r)
if user == nil {
http.Redirect(w, r, "/login?flash=Please sign in to vote", http.StatusSeeOther)
redirectURL := "/login?flash=" + url.QueryEscape("Please sign in to vote")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}

View File

@@ -36,11 +36,6 @@ func NewPostHandler(postRepo repositories.PostRepository, titleFetcher services.
type PostResponse = CommonResponse
type UpdatePostRequest struct {
Title string `json:"title"`
Content string `json:"content"`
}
// @Summary Get posts
// @Description Get a list of posts with pagination. Posts include vote statistics (up_votes, down_votes, score) and current user's vote status.
// @Tags posts
@@ -111,7 +106,7 @@ func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body CreatePostRequest true "Post data"
// @Param request body dto.CreatePostRequest true "Post data"
// @Success 201 {object} PostResponse
// @Failure 400 {object} PostResponse "Invalid request data or validation failed"
// @Failure 401 {object} PostResponse "Authentication required"
@@ -120,32 +115,9 @@ func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} PostResponse "Internal server error"
// @Router /api/posts [post]
func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
var req struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
req.Title = security.SanitizeInput(req.Title)
req.URL = security.SanitizeURL(req.URL)
req.Content = security.SanitizePostContent(req.Content)
if req.URL == "" {
SendErrorResponse(w, "URL is required", http.StatusBadRequest)
return
}
if len(req.Title) > 200 {
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
return
}
if len(req.Content) > 10000 {
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
req, ok := GetValidatedDTO[dto.CreatePostRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -154,13 +126,20 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return
}
title := req.Title
title := security.SanitizeInput(req.Title)
url := security.SanitizeURL(req.URL)
content := security.SanitizePostContent(req.Content)
if url == "" {
SendErrorResponse(w, "Invalid URL", http.StatusBadRequest)
return
}
if title == "" && h.titleFetcher != nil {
titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second)
defer cancel()
fetchedTitle, err := h.titleFetcher.FetchTitle(titleCtx, req.URL)
fetchedTitle, err := h.titleFetcher.FetchTitle(titleCtx, url)
if err != nil {
switch {
case errors.Is(err, services.ErrUnsupportedScheme):
@@ -186,10 +165,20 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return
}
if len(title) > 200 {
SendErrorResponse(w, "Title must be at most 200 characters", http.StatusBadRequest)
return
}
if len(content) > 10000 {
SendErrorResponse(w, "Content must be at most 10000 characters", http.StatusBadRequest)
return
}
post := &database.Post{
Title: title,
URL: req.URL,
Content: req.Content,
URL: url,
Content: content,
AuthorID: &userID,
}
@@ -257,7 +246,7 @@ func (h *PostHandler) SearchPosts(w http.ResponseWriter, r *http.Request) {
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Param request body UpdatePostRequest true "Post update data"
// @Param request body dto.UpdatePostRequest true "Post update data"
// @Success 200 {object} PostResponse "Post updated successfully"
// @Failure 400 {object} PostResponse "Invalid request data or validation failed"
// @Failure 401 {object} PostResponse "Authentication required"
@@ -286,40 +275,27 @@ func (h *PostHandler) UpdatePost(w http.ResponseWriter, r *http.Request) {
return
}
var req struct {
Title string `json:"title"`
Content string `json:"content"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.UpdatePostRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
req.Title = security.SanitizeInput(req.Title)
req.Content = security.SanitizePostContent(req.Content)
title := security.SanitizeInput(req.Title)
content := security.SanitizePostContent(req.Content)
if len(req.Title) > 200 {
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
return
}
if len(req.Content) > 10000 {
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
return
}
if err := validation.ValidateTitle(req.Title); err != nil {
if err := validation.ValidateTitle(title); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidateContent(req.Content); err != nil {
if err := validation.ValidateContent(content); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
post.Title = req.Title
post.Content = req.Content
post.Title = title
post.Content = content
if err := h.postRepo.Update(post); err != nil {
SendErrorResponse(w, "Failed to update post", http.StatusInternalServerError)
@@ -458,7 +434,7 @@ func (h *PostHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
if config.GeneralRateLimit != nil {
protected = config.GeneralRateLimit(protected)
}
protected.Post("/posts", h.CreatePost)
protected.Put("/posts/{id}", h.UpdatePost)
protected.Post("/posts", WithValidation[dto.CreatePostRequest](config.ValidationMiddleware, h.CreatePost))
protected.Put("/posts/{id}", WithValidation[dto.UpdatePostRequest](config.ValidationMiddleware, h.UpdatePost))
protected.Delete("/posts/{id}", h.DeletePost)
}

View File

@@ -69,9 +69,8 @@ func TestPostHandlerCreatePostWithTitleFetcher(t *testing.T) {
handler := NewPostHandler(repo, titleFetcher, nil)
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"url":"https://example.com","content":"Test content"}`))
request := createCreatePostRequest(`{"url":"https://example.com","content":"Test content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
@@ -171,7 +170,7 @@ func TestPostHandlerUpdatePostUnauthorized(t *testing.T) {
handler := NewPostHandler(repo, nil, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/1", bytes.NewBufferString(`{"title":"Updated Title","content":"Updated content"}`))
request := createUpdatePostRequest(`{"title":"Updated Title","content":"Updated content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
request.Header.Set("Content-Type", "application/json")
@@ -278,8 +277,7 @@ func TestPostHandlerCreatePostSuccess(t *testing.T) {
handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com","content":"Go"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request := createCreatePostRequest(`{"title":" ","url":"https://example.com","content":"Go"}`)
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42))
request = request.WithContext(ctx)
@@ -297,7 +295,7 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"","url":"","content":""}`))
request := createCreatePostRequest(`{"title":"","url":"","content":""}`)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
@@ -305,14 +303,14 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`invalid json`))
request = createCreatePostRequest(`invalid json`)
handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"ok","url":"https://example.com"}`))
request = createCreatePostRequest(`{"title":"ok","url":"https://example.com"}`)
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
}
@@ -336,8 +334,7 @@ func TestPostHandlerCreatePostTitleFetcherErrors(t *testing.T) {
return "", tc.err
}}
handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request := createCreatePostRequest(`{"title":" ","url":"https://example.com"}`)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
@@ -495,7 +492,7 @@ func TestPostHandlerUpdatePost(t *testing.T) {
}
handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/"+tt.postID, bytes.NewBufferString(tt.requestBody))
request := createUpdatePostRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -669,6 +666,9 @@ func (e *errorVoteRepository) Delete(uint) error { ret
func (e *errorVoteRepository) Count() (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByPostID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByUserID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) GetVoteCountsByPostID(uint) (int, int, error) {
return 0, 0, errors.New("database error")
}
func (e *errorVoteRepository) WithTx(*gorm.DB) repositories.VoteRepository { return e }
func TestPostHandler_EdgeCases(t *testing.T) {

View File

@@ -18,4 +18,5 @@ type RouteModuleConfig struct {
AuthRateLimit func(chi.Router) chi.Router
CSRFMiddleware func(http.Handler) http.Handler
AuthMiddleware func(http.Handler) http.Handler
ValidationMiddleware func(http.Handler) http.Handler
}

View File

@@ -24,10 +24,6 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
t.Run("XSS_"+payload[:minLen(20, len(payload))], func(t *testing.T) {
repo := &testutils.PostRepositoryStub{
CreateFn: func(post *database.Post) error {
sanitizedTitle := security.SanitizeInput(payload)
if post.Title != sanitizedTitle {
t.Errorf("Expected sanitized title, got %q", post.Title)
}
return nil
},
}
@@ -41,14 +37,46 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
}
body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
if recorder.Code != http.StatusCreated {
t.Errorf("Expected status %d, got %d. Body: %s", http.StatusCreated, recorder.Code, recorder.Body.String())
return
}
var response CommonResponse
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if !response.Success {
t.Errorf("Expected successful response, got error: %s", response.Error)
return
}
dataMap, ok := response.Data.(map[string]any)
if !ok {
t.Fatalf("Expected data to be a map, got %T", response.Data)
}
title, ok := dataMap["title"].(string)
if !ok {
t.Fatalf("Expected title to be a string, got %T", dataMap["title"])
}
expectedSanitized := security.SanitizeInput(payload)
if title != expectedSanitized {
t.Errorf("Expected sanitized title %q, got %q", expectedSanitized, title)
}
if title == payload {
t.Errorf("Title was not sanitized - original payload %q matches response %q", payload, title)
}
})
}
}
@@ -123,7 +151,7 @@ func TestPostHandler_InputValidation(t *testing.T) {
}
body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
@@ -230,7 +258,7 @@ func TestAuthHandler_PasswordValidation(t *testing.T) {
}
body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
@@ -290,7 +318,7 @@ func TestAuthHandler_UsernameSanitization(t *testing.T) {
}
body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()

View File

@@ -91,7 +91,7 @@ func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body RegisterRequest true "User data"
// @Param request body dto.RegisterRequest true "User data"
// @Success 201 {object} UserResponse "User created successfully"
// @Failure 400 {object} UserResponse "Invalid request data or validation failed"
// @Failure 401 {object} UserResponse "Authentication required"
@@ -99,13 +99,9 @@ func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} UserResponse "Internal server error"
// @Router /api/users [post]
func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.RegisterRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -189,7 +185,7 @@ func (h *UserHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
}
protected.Get("/users", h.GetUsers)
protected.Post("/users", h.CreateUser)
protected.Post("/users", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.CreateUser))
protected.Get("/users/{id}", h.GetUser)
protected.Get("/users/{id}/posts", h.GetUserPosts)
}

View File

@@ -1,7 +1,6 @@
package handlers
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -103,7 +102,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
return nil
}})
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
request := createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
@@ -126,14 +125,14 @@ func TestUserHandlerCreateUser(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString("invalid"))
request = createRegisterRequest("invalid")
handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"","email":"","password":""}`))
request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode)
@@ -144,7 +143,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
}
handler = newUserHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
}
@@ -350,7 +349,7 @@ func TestUserHandler_PasswordValidation(t *testing.T) {
handler := NewUserHandler(repo, authService)
requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password)
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(requestBody))
request := createRegisterRequest(requestBody)
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/services"
"github.com/go-chi/chi/v5"
@@ -39,11 +40,6 @@ func NewVoteHandler(voteService *services.VoteService) *VoteHandler {
}
}
// @Description Vote request with type field. All votes are handled the same way.
type VoteRequest struct {
Type string `json:"type" example:"up" enums:"up,down,none" description:"Vote type: 'up' for upvote, 'down' for downvote, 'none' to remove vote"`
}
type VoteResponse = CommonResponse
// @Summary Cast a vote on a post
@@ -62,7 +58,7 @@ type VoteResponse = CommonResponse
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Param request body VoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)"
// @Param request body dto.CastVoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)"
// @Success 200 {object} VoteResponse "Vote cast successfully with updated post statistics"
// @Failure 401 {object} VoteResponse "Authentication required"
// @Failure 400 {object} VoteResponse "Invalid request data or vote type"
@@ -82,8 +78,9 @@ func (h *VoteHandler) CastVote(w http.ResponseWriter, r *http.Request) {
return
}
var req VoteRequest
if !DecodeJSONRequest(w, r, &req) {
req, ok := GetValidatedDTO[dto.CastVoteRequest](r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -286,7 +283,7 @@ func (h *VoteHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
protected = config.GeneralRateLimit(protected)
}
protected.Post("/posts/{id}/vote", h.CastVote)
protected.Post("/posts/{id}/vote", WithValidation[dto.CastVoteRequest](config.ValidationMiddleware, h.CastVote))
protected.Delete("/posts/{id}/vote", h.RemoveVote)
protected.Get("/posts/{id}/vote", h.GetUserVote)
protected.Get("/posts/{id}/votes", h.GetPostVotes)

View File

@@ -1,7 +1,6 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -59,13 +58,13 @@ func TestVoteHandlerCastVote(t *testing.T) {
handler := newVoteHandlerWithRepos()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/abc/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -73,7 +72,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`invalid`))
request = createVoteRequest(`invalid`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -83,7 +82,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"maybe"}`))
request = createVoteRequest(`{"type":"maybe"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -93,7 +92,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -101,7 +100,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -111,7 +110,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx)
@@ -125,7 +124,7 @@ func TestVoteHandlerCastVotePostNotFound(t *testing.T) {
handler, _, posts := newVoteHandlerWithReposRefs()
delete(posts, 1)
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -164,7 +163,7 @@ func TestVoteHandlerRemoveVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -202,7 +201,7 @@ func TestVoteHandlerRemoveVotePostNotFound(t *testing.T) {
func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) {
handler, voteRepo, _ := newVoteHandlerWithReposRefs()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -257,7 +256,7 @@ func TestVoteHandlerGetUserVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -301,7 +300,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -311,7 +310,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -345,7 +344,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1"
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -363,7 +362,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -373,7 +372,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -404,7 +403,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1"
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -414,7 +413,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -424,7 +423,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx)
@@ -452,7 +451,7 @@ func TestVoteFlowRegression(t *testing.T) {
t.Run("ErrorHandlingEdgeCases", func(t *testing.T) {
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(``))
request := createVoteRequest(``)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -460,7 +459,7 @@ func TestVoteFlowRegression(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{}`))
request = createVoteRequest(`{}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -470,7 +469,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"invalid"}`))
request = createVoteRequest(`{"type":"invalid"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)

View File

@@ -53,25 +53,25 @@ func TestIntegration_Caching(t *testing.T) {
router := ctx.Router
t.Run("Cache_Hit_On_Repeated_Requests", func(t *testing.T) {
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond)
req2 := httptest.NewRequest("GET", "/api/posts", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("GET", "/api/posts", nil)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
if rec1.Code != rec2.Code {
if firstRecorder.Code != secondRecorder.Code {
t.Error("Cached responses should have same status code")
}
if rec1.Body.String() != rec2.Body.String() {
if firstRecorder.Body.String() != secondRecorder.Body.String() {
t.Error("Cached responses should have same body")
}
if rec2.Header().Get("X-Cache") != "HIT" {
if secondRecorder.Header().Get("X-Cache") != "HIT" {
t.Log("Cache may not be enabled for this path or response may not be cacheable")
}
})
@@ -80,9 +80,9 @@ func TestIntegration_Caching(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "cache_post_user", "cache_post@example.com")
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond)
@@ -92,12 +92,12 @@ func TestIntegration_Caching(t *testing.T) {
"content": "Test content",
}
body, _ := json.Marshal(postBody)
req2 := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
secondRequest.Header.Set("Content-Type", "application/json")
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
time.Sleep(10 * time.Millisecond)
@@ -105,17 +105,17 @@ func TestIntegration_Caching(t *testing.T) {
rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3)
if rec1.Body.String() == rec3.Body.String() && rec1.Code == http.StatusOK && rec3.Code == http.StatusOK {
if firstRecorder.Body.String() == rec3.Body.String() && firstRecorder.Code == http.StatusOK && rec3.Code == http.StatusOK {
t.Log("Cache invalidation may not be working or cache may not be enabled")
}
})
t.Run("Cache_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if rec.Header().Get("Cache-Control") == "" && rec.Header().Get("X-Cache") == "" {
if recorder.Header().Get("Cache-Control") == "" && recorder.Header().Get("X-Cache") == "" {
t.Log("Cache headers may not be present for all responses")
}
})
@@ -126,18 +126,18 @@ func TestIntegration_Caching(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Cache Delete Post", "https://example.com/cache-delete")
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond)
req2 := httptest.NewRequest("DELETE", "/api/posts/"+fmt.Sprintf("%d", post.ID), nil)
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
req2 = testutils.WithURLParams(req2, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("DELETE", "/api/posts/"+fmt.Sprintf("%d", post.ID), nil)
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRequest = testutils.WithURLParams(secondRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
time.Sleep(10 * time.Millisecond)
@@ -145,7 +145,7 @@ func TestIntegration_Caching(t *testing.T) {
rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3)
if rec1.Body.String() == rec3.Body.String() && rec1.Code == http.StatusOK && rec3.Code == http.StatusOK {
if firstRecorder.Body.String() == rec3.Body.String() && firstRecorder.Code == http.StatusOK && rec3.Code == http.StatusOK {
t.Log("Cache invalidation may not be working or cache may not be enabled")
}
})

View File

@@ -1,15 +1,14 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"goyco/internal/middleware"
"goyco/internal/services"
"goyco/internal/testutils"
)
@@ -20,17 +19,8 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "logout_user", "logout@example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/logout", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makePostRequest(t, ctx.Router, "/api/auth/logout", map[string]any{}, user, nil)
assertStatus(t, request, http.StatusOK)
})
t.Run("Auth_Revoke_Token_Endpoint", func(t *testing.T) {
@@ -42,52 +32,23 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Fatalf("Failed to login: %v", err)
}
reqBody := map[string]string{
"refresh_token": loginResult.RefreshToken,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/revoke", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makePostRequest(t, ctx.Router, "/api/auth/revoke", map[string]any{"refresh_token": loginResult.RefreshToken}, user, nil)
assertStatus(t, request, http.StatusOK)
})
t.Run("Auth_Revoke_All_Tokens_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "revoke_all_user", "revoke_all@example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/revoke-all", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makePostRequest(t, ctx.Router, "/api/auth/revoke-all", map[string]any{}, user, nil)
assertStatus(t, request, http.StatusOK)
})
t.Run("Auth_Resend_Verification_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
"email": "resend@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/resend-verification", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusOK, http.StatusNotFound)
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/resend-verification", map[string]any{"email": "resend@example.com"})
assertStatusRange(t, request, http.StatusOK, http.StatusNotFound)
})
t.Run("Auth_Confirm_Email_Endpoint", func(t *testing.T) {
@@ -99,109 +60,66 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
token = "test-token"
}
req := httptest.NewRequest("GET", "/api/auth/confirm?token="+url.QueryEscape(token), nil)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusOK, http.StatusBadRequest)
request := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(token))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Auth_Update_Email_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_email_api_user", "update_email_api@example.com")
reqBody := map[string]string{
"email": "newemail@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "newemail@example.com"}, user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if email, ok := data["email"].(string); ok && email != "newemail@example.com" {
t.Errorf("Expected email to be updated, got %s", email)
}
}
}
})
t.Run("Auth_Update_Username_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_username_api_user", "update_username_api@example.com")
reqBody := map[string]string{
"username": "new_username",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/username", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePutRequest(t, ctx.Router, "/api/auth/username", map[string]any{"username": "new_username"}, user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if username, ok := data["username"].(string); ok && username != "new_username" {
t.Errorf("Expected username to be updated, got %s", username)
}
}
}
})
t.Run("Users_List_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_list_user", "users_list@example.com")
req := httptest.NewRequest("GET", "/api/users", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["users"]; !exists {
t.Error("Expected users in response")
}
}
}
})
t.Run("Users_Get_By_ID_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_get_user", "users_get@example.com")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != user.User.ID {
t.Errorf("Expected user ID %d, got %.0f", user.User.ID, id)
}
}
}
}
})
t.Run("Users_Get_Posts_Endpoint", func(t *testing.T) {
@@ -210,17 +128,10 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "User Posts Test", "https://example.com/user-posts")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) == 0 {
t.Error("Expected at least one post in response")
@@ -229,35 +140,24 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Error("Expected posts array in response")
}
}
}
})
t.Run("Users_Create_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_create_admin", "users_create_admin@example.com")
reqBody := map[string]string{
request := makePostRequest(t, ctx.Router, "/api/users", map[string]any{
"username": "created_user",
"email": "created@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
}, user, nil)
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusCreated)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusCreated)
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["user"]; !exists {
t.Error("Expected user in response")
}
}
}
})
t.Run("Posts_Update_Endpoint", func(t *testing.T) {
@@ -266,30 +166,19 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Update Test Post", "https://example.com/update-test")
reqBody := map[string]string{
request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if postData, ok := data["post"].(map[string]any); ok {
if title, ok := postData["title"].(string); ok && title != "Updated Title" {
t.Errorf("Expected title 'Updated Title', got '%s'", title)
}
}
}
}
})
t.Run("Posts_Delete_Endpoint", func(t *testing.T) {
@@ -298,20 +187,11 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Test Post", "https://example.com/delete-test")
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, request, http.StatusOK)
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
assertStatus(t, getRec, http.StatusNotFound)
getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID))
assertStatus(t, getRequest, http.StatusNotFound)
})
t.Run("Votes_Get_All_Endpoint", func(t *testing.T) {
@@ -319,28 +199,11 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "votes_get_all_user", "votes_get_all@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Votes Test Post", "https://example.com/votes-test")
makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteBody := map[string]string{"type": "up"}
voteBodyBytes, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
req := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if votes, ok := data["votes"].([]any); ok {
if len(votes) == 0 {
t.Error("Expected at least one vote in response")
@@ -349,7 +212,6 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Error("Expected votes array in response")
}
}
}
})
t.Run("Votes_Remove_Endpoint", func(t *testing.T) {
@@ -358,49 +220,461 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Remove Test", "https://example.com/vote-remove")
voteBody := map[string]string{"type": "up"}
voteBodyBytes, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, request, http.StatusOK)
})
t.Run("API_Info_Endpoint", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, ctx.Router, "/api")
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["endpoints"]; !exists {
t.Error("Expected endpoints in API info")
}
}
}
})
t.Run("Swagger_Documentation_Endpoint", func(t *testing.T) {
req := httptest.NewRequest("GET", "/swagger/index.html", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, ctx.Router, "/swagger/index.html")
assertStatusRange(t, request, http.StatusOK, http.StatusNotFound)
})
ctx.Router.ServeHTTP(rec, req)
t.Run("Search_Endpoint_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "search_edge"), uniqueTestEmail(t, "search_edge"))
assertStatusRange(t, rec, http.StatusOK, http.StatusNotFound)
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Searchable Post One", "https://example.com/one")
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Searchable Post Two", "https://example.com/two")
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Different Content", "https://example.com/three")
t.Run("Empty_Search_Results", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=nonexistentterm12345")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) != 0 {
t.Errorf("Expected empty search results, got %d posts", len(posts))
}
}
if count, ok := data["count"].(float64); ok && count != 0 {
t.Errorf("Expected count 0, got %.0f", count)
}
}
})
t.Run("Search_With_Pagination", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=Searchable&limit=1&offset=0")
response := assertJSONResponse(t, request, http.StatusOK)
var firstPostID any
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 1 {
t.Errorf("Expected at most 1 post with limit=1, got %d", len(posts))
}
if len(posts) > 0 {
if post, ok := posts[0].(map[string]any); ok {
firstPostID = post["id"]
}
}
}
if limit, ok := data["limit"].(float64); ok && limit != 1 {
t.Errorf("Expected limit 1 in response, got %.0f", limit)
}
if offset, ok := data["offset"].(float64); ok && offset != 0 {
t.Errorf("Expected offset 0 in response, got %.0f", offset)
}
}
secondRequest := makeGetRequest(t, ctx.Router, "/api/posts/search?q=Searchable&limit=1&offset=1")
secondResponse := assertJSONResponse(t, secondRequest, http.StatusOK)
if data, ok := getDataFromResponse(secondResponse); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 1 {
t.Errorf("Expected at most 1 post with limit=1 and offset=1, got %d", len(posts))
}
if len(posts) > 0 && firstPostID != nil {
if post, ok := posts[0].(map[string]any); ok {
if post["id"] == firstPostID {
t.Error("Expected different post with offset=1, got same post as offset=0")
}
}
}
}
}
})
t.Run("Search_With_Special_Characters", func(t *testing.T) {
specialQueries := []string{
"Searchable%20Post",
"Searchable'Post",
"Searchable\"Post",
"Searchable;Post",
"Searchable--Post",
}
for _, query := range specialQueries {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q="+url.QueryEscape(query))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
}
})
t.Run("Search_Empty_Query", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) != 0 {
t.Errorf("Expected empty results for empty query, got %d posts", len(posts))
}
}
if count, ok := data["count"].(float64); ok && count != 0 {
t.Errorf("Expected count 0 for empty query, got %.0f", count)
}
}
})
t.Run("Search_With_Very_Long_Query", func(t *testing.T) {
longQuery := strings.Repeat("a", 1000)
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q="+url.QueryEscape(longQuery))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Search_Case_Insensitive", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=SEARCHABLE")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) == 0 {
t.Error("Expected case-insensitive search to find posts")
}
}
}
})
})
t.Run("Title_Fetch_Endpoint_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
t.Run("Missing_URL_Parameter", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/title")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Empty_URL_Parameter", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Invalid_URL_Format", func(t *testing.T) {
invalidURLs := []string{
"not-a-url",
"://invalid",
"http://",
"https://",
}
for _, invalidURL := range invalidURLs {
ctx.Suite.TitleFetcher.SetError(services.ErrUnsupportedScheme)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(invalidURL))
assertErrorResponse(t, request, http.StatusBadRequest)
}
})
t.Run("Unsupported_URL_Schemes", func(t *testing.T) {
unsupportedSchemes := []string{
"ftp://example.com",
"file:///etc/passwd",
"javascript:alert(1)",
"data:text/html,<script>alert(1)</script>",
}
for _, schemeURL := range unsupportedSchemes {
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(schemeURL))
assertErrorResponse(t, request, http.StatusBadRequest)
}
})
t.Run("SSRF_Protection_Localhost", func(t *testing.T) {
ssrfURLs := []string{
"http://localhost",
"http://127.0.0.1",
"http://127.0.0.1:8080",
"http://[::1]",
"http://0.0.0.0",
}
for _, ssrfURL := range ssrfURLs {
ctx.Suite.TitleFetcher.SetError(services.ErrSSRFBlocked)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(ssrfURL))
assertStatusRange(t, request, http.StatusBadRequest, http.StatusBadGateway)
}
})
t.Run("SSRF_Protection_Private_IPs", func(t *testing.T) {
privateIPs := []string{
"http://192.168.1.1",
"http://10.0.0.1",
"http://172.16.0.1",
}
for _, privateIP := range privateIPs {
ctx.Suite.TitleFetcher.SetError(services.ErrSSRFBlocked)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(privateIP))
assertStatusRange(t, request, http.StatusBadRequest, http.StatusBadGateway)
}
})
t.Run("Title_Fetch_Error_Handling", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetError(services.ErrTitleNotFound)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=https://example.com/notitle")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Valid_URL_Success", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Valid Title")
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=https://example.com/valid")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if title, ok := data["title"].(string); ok {
if title != "Valid Title" {
t.Errorf("Expected title 'Valid Title', got '%s'", title)
}
} else {
t.Error("Expected title in response data")
}
}
})
})
t.Run("Get_User_Vote_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "vote_edge"), uniqueTestEmail(t, "vote_edge"))
secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "vote_edge2"), uniqueTestEmail(t, "vote_edge2"))
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Edge Test Post", "https://example.com/vote-edge")
t.Run("Get_Vote_When_User_Has_Voted", func(t *testing.T) {
voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRequest, http.StatusOK)
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if hasVote, ok := data["has_vote"].(bool); !ok || !hasVote {
t.Error("Expected has_vote to be true when user has voted")
}
if vote, ok := data["vote"]; !ok || vote == nil {
t.Error("Expected vote object when user has voted")
}
}
})
t.Run("Get_Vote_When_User_Has_Not_Voted", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if hasVote, ok := data["has_vote"].(bool); ok {
if hasVote {
t.Error("Expected has_vote to be false when user has not voted")
}
} else {
t.Error("Expected has_vote field in response")
}
}
})
t.Run("Get_Vote_Invalid_Post_ID", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/posts/999999/vote", user, map[string]string{"id": "999999"})
if request.Code != http.StatusOK && request.Code != http.StatusNotFound {
t.Errorf("Expected status 200 or 404 for invalid post ID, got %d", request.Code)
}
})
t.Run("Get_Vote_Unauthenticated", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID))
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Get_Vote_Response_Structure", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success field to be true")
}
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["has_vote"]; !exists {
t.Error("Expected has_vote field in response data")
}
if _, exists := data["is_anonymous"]; !exists {
t.Error("Expected is_anonymous field in response data")
}
} else {
t.Error("Expected data field in response")
}
})
})
t.Run("Refresh_Token_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
refreshUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "refresh_edge"), uniqueTestEmail(t, "refresh_edge"))
t.Run("Refresh_With_Expired_Token", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
refreshToken, err := ctx.Suite.RefreshTokenRepo.GetByTokenHash(testutils.HashVerificationToken(loginResult.RefreshToken))
if err != nil {
t.Fatalf("Failed to get refresh token: %v", err)
}
refreshToken.ExpiresAt = time.Now().Add(-1 * time.Hour)
if err := ctx.Suite.DB.Model(refreshToken).Update("expires_at", refreshToken.ExpiresAt).Error; err != nil {
t.Fatalf("Failed to expire refresh token: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_With_Revoked_Token", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil {
t.Fatalf("Failed to revoke refresh token: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_With_Empty_Token", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": ""})
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Refresh_With_Missing_Token_Field", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{})
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Refresh_Token_Rotation", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
originalRefreshToken := loginResult.RefreshToken
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": originalRefreshToken})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if newAccessToken, ok := data["access_token"].(string); ok {
if newAccessToken == "" {
t.Error("Expected new access token in refresh response")
}
if newRefreshToken, ok := data["refresh_token"].(string); ok {
if newRefreshToken != "" && newRefreshToken == originalRefreshToken {
t.Log("Refresh token rotation may not be implemented (same token returned)")
}
}
}
}
})
t.Run("Refresh_After_Account_Lock", func(t *testing.T) {
lockedUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "refresh_lock"), uniqueTestEmail(t, "refresh_lock"))
loginResult, err := ctx.AuthService.Login(lockedUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
lockedUser.User.Locked = true
if err := ctx.Suite.UserRepo.Update(lockedUser.User); err != nil {
t.Fatalf("Failed to lock user: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertStatusRange(t, request, http.StatusUnauthorized, http.StatusForbidden)
})
t.Run("Refresh_With_Invalid_Token_Format", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": "invalid-token-format-12345"})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
})
t.Run("Pagination_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
paginationUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "pagination_edge"), uniqueTestEmail(t, "pagination_edge"))
for i := 0; i < 5; i++ {
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, paginationUser.User.ID, fmt.Sprintf("Pagination Post %d", i), fmt.Sprintf("https://example.com/pag%d", i))
}
t.Run("Negative_Limit", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?limit=-1")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Negative_Offset", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?offset=-1")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Very_Large_Limit", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?limit=10000")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Very_Large_Offset", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?offset=10000")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 0 {
t.Logf("Large offset returned %d posts (may be expected)", len(posts))
}
}
}
})
t.Run("Invalid_Pagination_Parameters", func(t *testing.T) {
invalidParams := []string{
"limit=abc",
"offset=xyz",
"limit=",
"offset=",
}
for _, param := range invalidParams {
request := makeGetRequest(t, ctx.Router, "/api/posts?"+param)
assertStatus(t, request, http.StatusOK)
}
})
})
}

View File

@@ -1,17 +1,12 @@
package integration
import (
"bytes"
"compress/gzip"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"goyco/internal/middleware"
"goyco/internal/testutils"
)
func TestIntegration_Compression(t *testing.T) {
@@ -19,16 +14,16 @@ func TestIntegration_Compression(t *testing.T) {
router := ctx.Router
t.Run("Response_Compression_Gzip", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip")
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
request.Header.Set("Accept-Encoding", "gzip")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
contentEncoding := rec.Header().Get("Content-Encoding")
contentEncoding := recorder.Header().Get("Content-Encoding")
if contentEncoding != "" && strings.Contains(contentEncoding, "gzip") {
assertHeaderContains(t, rec, "Content-Encoding", "gzip")
reader, err := gzip.NewReader(rec.Body)
assertHeaderContains(t, recorder, "Content-Encoding", "gzip")
reader, err := gzip.NewReader(recorder.Body)
if err != nil {
t.Fatalf("Failed to create gzip reader: %v", err)
}
@@ -48,14 +43,14 @@ func TestIntegration_Compression(t *testing.T) {
})
t.Run("Compression_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip, deflate")
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
request.Header.Set("Accept-Encoding", "gzip, deflate")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Header().Get("Vary") != "" {
assertHeaderContains(t, rec, "Vary", "Accept-Encoding")
if recorder.Header().Get("Vary") != "" {
assertHeaderContains(t, recorder, "Vary", "Accept-Encoding")
} else {
t.Log("Vary header may not always be present")
}
@@ -67,25 +62,19 @@ func TestIntegration_StaticFiles(t *testing.T) {
router := ctx.Router
t.Run("Robots_Txt_Served", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/robots.txt")
router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
if !strings.Contains(rec.Body.String(), "User-agent") {
if !strings.Contains(request.Body.String(), "User-agent") {
t.Error("Expected robots.txt content")
}
})
t.Run("Static_Files_Security_Headers", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/robots.txt")
router.ServeHTTP(rec, req)
if rec.Header().Get("X-Content-Type-Options") == "" {
if request.Header().Get("X-Content-Type-Options") == "" {
t.Log("Security headers may not be applied to all static files")
}
})
@@ -101,32 +90,22 @@ func TestIntegration_URLMetadata(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Fetched Title")
postBody := map[string]string{
postBody := map[string]any{
"title": "Test Post",
"url": "https://example.com/metadata-test",
"content": "Test content",
}
body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePostRequest(t, router, "/api/posts", postBody, user, nil)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
})
t.Run("URL_Metadata_Endpoint", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Endpoint Title")
req := httptest.NewRequest("GET", "/api/posts/title?url=https://example.com/test", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/api/posts/title?url=https://example.com/test")
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["title"]; !exists {

View File

@@ -1,14 +1,10 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"goyco/internal/middleware"
"goyco/internal/testutils"
)
@@ -22,33 +18,19 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Owner Post", "https://example.com/owner")
updateBody := map[string]string{
request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(updateBody)
}, otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
assertErrorResponse(t, request, http.StatusForbidden)
ctx.Router.ServeHTTP(rec, req)
request = makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}, owner, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertErrorResponse(t, rec, http.StatusForbidden)
req = httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+owner.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, owner.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("Post_Delete_Authorization", func(t *testing.T) {
@@ -58,47 +40,27 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Delete Post", "https://example.com/delete")
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, request, http.StatusForbidden)
assertErrorResponse(t, rec, http.StatusForbidden)
request = makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), owner, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req = httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+owner.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, owner.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("User_Profile_Access_Authorization", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user1", "profile_user1@example.com")
user2 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user2", "profile_user2@example.com")
firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user1", "profile_user1@example.com")
secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user2", "profile_user2@example.com")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user1.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user2.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user2.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user1.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", firstUser.User.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", firstUser.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != user1.User.ID {
t.Errorf("Expected user ID %d, got %.0f", user1.User.ID, id)
}
if id, ok := userData["id"].(float64); ok && uint(id) != firstUser.User.ID {
t.Errorf("Expected user ID %d, got %.0f", firstUser.User.ID, id)
}
}
}
@@ -109,24 +71,13 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_user", "settings_auth@example.com")
otherUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_other", "settings_auth_other@example.com")
updateBody := map[string]string{
"email": "newemail@example.com",
}
body, _ := json.Marshal(updateBody)
request := makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "newemail@example.com"}, otherUser, nil)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response == nil {
return
}
if data, ok := response["data"].(map[string]any); ok {
if data, ok := getDataFromResponse(response); ok {
if userData, ok := data["user"].(map[string]any); ok {
if email, ok := userData["email"].(string); ok && email == "newemail@example.com" {
if id, ok := userData["id"].(float64); ok && uint(id) != otherUser.User.ID {
@@ -136,20 +87,9 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
}
}
updateBody2 := map[string]string{
"email": "anothernewemail@example.com",
}
body2, _ := json.Marshal(updateBody2)
request = makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "anothernewemail@example.com"}, user, nil)
req = httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body2))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("Vote_Authorization", func(t *testing.T) {
@@ -159,73 +99,42 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, postOwner.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
request := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
assertStatus(t, request, http.StatusOK)
ctx.Router.ServeHTTP(rec, req)
request = makePostRequestWithJSON(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"})
assertStatus(t, rec, http.StatusOK)
req = httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Protected_Endpoint_Without_Auth", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, ctx.Router, "/api/posts", map[string]any{})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Protected_Endpoint_With_Invalid_Token", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer invalid-token")
rec := httptest.NewRecorder()
request := makeRequest(t, ctx.Router, "POST", "/api/posts", []byte("{}"), map[string]string{"Content-Type": "application/json", "Authorization": "Bearer invalid-token"})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("User_List_Authorization", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "list_auth_user", "list_auth@example.com")
req := httptest.NewRequest("GET", "/api/users", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
request = makeGetRequest(t, ctx.Router, "/api/users")
req = httptest.NewRequest("GET", "/api/users", nil)
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_Token_Authorization", func(t *testing.T) {
@@ -237,18 +146,9 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
t.Fatalf("Failed to login: %v", err)
}
refreshBody := map[string]string{
"refresh_token": loginResult.RefreshToken,
}
body, _ := json.Marshal(refreshBody)
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
req := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response == nil {
return
}
@@ -260,17 +160,8 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
t.Error("Expected data field in refresh response")
}
refreshBody = map[string]string{
"refresh_token": "invalid-refresh-token",
}
body, _ = json.Marshal(refreshBody)
request = makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": "invalid-refresh-token"})
req = httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
}

View File

@@ -14,166 +14,137 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
ctx := setupPageHandlerTestContext(t)
router := ctx.Router
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) {
reqBody := url.Values{}
reqBody.Set("username", "testuser")
reqBody.Set("email", "test@example.com")
reqBody.Set("password", "SecurePass123!")
getCSRFToken := func(t *testing.T, path string, cookies ...*http.Cookie) *http.Cookie {
t.Helper()
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", rec.Code, rec.Body.String())
request := httptest.NewRequest("GET", path, nil)
for _, c := range cookies {
request.AddCookie(c)
}
if !strings.Contains(rec.Body.String(), "Invalid CSRF token") {
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
for _, cookie := range recorder.Result().Cookies() {
if cookie.Name == "csrf_token" {
return cookie
}
}
t.Fatalf("Expected CSRF cookie to be set for %s", path)
return nil
}
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) {
requestBody := url.Values{}
requestBody.Set("username", "testuser")
requestBody.Set("email", "test@example.com")
requestBody.Set("password", "SecurePass123!")
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", recorder.Code, recorder.Body.String())
}
if !strings.Contains(recorder.Body.String(), "Invalid CSRF token") {
t.Error("Expected CSRF error message")
}
})
t.Run("CSRF_Allows_Form_With_Valid_Token", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
csrfCookie := getCSRFToken(t, "/register")
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
}
}
requestBody := url.Values{}
requestBody.Set("username", "csrf_user")
requestBody.Set("email", "csrf@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", csrfCookie.Value)
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value
router.ServeHTTP(recorder, request)
reqBody := url.Values{}
reqBody.Set("username", "csrf_user")
reqBody.Set("email", "csrf@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected form submission with valid CSRF token to succeed")
}
})
t.Run("CSRF_Allows_API_Requests", func(t *testing.T) {
reqBody := map[string]string{
requestBody := map[string]string{
"username": "api_user",
"email": "api@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected API requests to bypass CSRF protection")
}
})
t.Run("CSRF_Blocks_Mismatched_Token", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
csrfCookie := getCSRFToken(t, "/register")
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
requestBody := url.Values{}
requestBody.Set("username", "mismatch_user")
requestBody.Set("email", "mismatch@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", "wrong-token")
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", recorder.Code, recorder.Body.String())
}
}
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
reqBody := url.Values{}
reqBody.Set("username", "mismatch_user")
reqBody.Set("email", "mismatch@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", "wrong-token")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", rec.Code, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "Invalid CSRF token") {
if !strings.Contains(recorder.Body.String(), "Invalid CSRF token") {
t.Error("Expected CSRF error message")
}
})
t.Run("CSRF_Allows_GET_Requests", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/register", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected GET requests to bypass CSRF protection")
}
})
t.Run("CSRF_Token_In_Header", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
csrfCookie := getCSRFToken(t, "/register")
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
}
}
requestBody := url.Values{}
requestBody.Set("username", "header_user")
requestBody.Set("email", "header@example.com")
requestBody.Set("password", "SecurePass123!")
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.Header.Set("X-CSRF-Token", csrfCookie.Value)
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value
router.ServeHTTP(recorder, request)
reqBody := url.Values{}
reqBody.Set("username", "header_user")
reqBody.Set("email", "header@example.com")
reqBody.Set("password", "SecurePass123!")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-CSRF-Token", csrfToken)
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected CSRF token in header to be accepted")
}
})
@@ -182,41 +153,24 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "csrf_form_user", "csrf_form@example.com")
getReq := httptest.NewRequest("GET", "/posts/new", nil)
getReq.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
authCookie := &http.Cookie{Name: "auth_token", Value: user.Token}
csrfCookie := getCSRFToken(t, "/posts/new", authCookie)
cookies := getRec.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
csrfCookie = cookie
break
}
}
requestBody := url.Values{}
requestBody.Set("title", "CSRF Test Post")
requestBody.Set("url", "https://example.com/csrf-test")
requestBody.Set("content", "Test content")
requestBody.Set("csrf_token", csrfCookie.Value)
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
request := httptest.NewRequest("POST", "/posts", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(authCookie)
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value
router.ServeHTTP(recorder, request)
reqBody := url.Values{}
reqBody.Set("title", "CSRF Test Post")
reqBody.Set("url", "https://example.com/csrf-test")
reqBody.Set("content", "Test content")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/posts", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
if recorder.Code == http.StatusForbidden {
t.Error("Expected post creation with valid CSRF token to succeed")
}
})

View File

@@ -19,27 +19,18 @@ func TestIntegration_DataConsistency(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "consistency_user", "consistency@example.com")
postBody := map[string]string{
request := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "Consistency Test Post",
"url": "https://example.com/consistency",
"content": "Test content",
}
body, _ := json.Marshal(postBody)
}, user, nil)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
createResponse := assertJSONResponse(t, rec, http.StatusCreated)
createResponse := assertJSONResponse(t, request, http.StatusCreated)
if createResponse == nil {
return
}
postData, ok := createResponse["data"].(map[string]any)
postData, ok := getDataFromResponse(createResponse)
if !ok {
t.Fatal("Response missing data")
}
@@ -53,16 +44,14 @@ func TestIntegration_DataConsistency(t *testing.T) {
createdURL := postData["url"]
createdContent := postData["content"]
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getResponse := assertJSONResponse(t, getRec, http.StatusOK)
getResponse := assertJSONResponse(t, getRequest, http.StatusOK)
if getResponse == nil {
return
}
getPostData, ok := getResponse["data"].(map[string]any)
getPostData, ok := getDataFromResponse(getResponse)
if !ok {
t.Fatal("Get response missing data")
}
@@ -96,32 +85,17 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Consistency Post", "https://example.com/vote-consistency")
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRequest, http.StatusOK)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
getVotesRequest := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRec, http.StatusOK)
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRequest, http.StatusOK)
if votesResponse == nil {
return
}
votesData, ok := votesResponse["data"].(map[string]any)
votesData, ok := getDataFromResponse(votesResponse)
if !ok {
t.Fatal("Votes response missing data")
}
@@ -172,32 +146,21 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Original Title", "https://example.com/original")
updateBody := map[string]string{
updateRequest := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(updateBody)
}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(updateRec, updateReq)
assertStatus(t, updateRequest, http.StatusOK)
assertStatus(t, updateRec, http.StatusOK)
getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID))
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getResponse := assertJSONResponse(t, getRec, http.StatusOK)
getResponse := assertJSONResponse(t, getRequest, http.StatusOK)
if getResponse == nil {
return
}
getPostData, ok := getResponse["data"].(map[string]any)
getPostData, ok := getDataFromResponse(getResponse)
if !ok {
t.Fatal("Get response missing data")
}
@@ -215,18 +178,12 @@ func TestIntegration_DataConsistency(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "user_posts_consistency", "user_posts_consistency@example.com")
post1 := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 1", "https://example.com/post1")
post2 := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 2", "https://example.com/post2")
firstPost := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 1", "https://example.com/post1")
secondPost := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 2", "https://example.com/post2")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response == nil {
return
}
@@ -245,26 +202,26 @@ func TestIntegration_DataConsistency(t *testing.T) {
t.Errorf("Expected at least 2 posts, got %d", len(posts))
}
foundPost1 := false
foundPost2 := false
foundFirstPost := false
foundSecondPost := false
for _, post := range posts {
if postMap, ok := post.(map[string]any); ok {
if postID, ok := postMap["id"].(float64); ok {
if uint(postID) == post1.ID {
foundPost1 = true
if uint(postID) == firstPost.ID {
foundFirstPost = true
}
if uint(postID) == post2.ID {
foundPost2 = true
if uint(postID) == secondPost.ID {
foundSecondPost = true
}
}
}
}
if !foundPost1 {
if !foundFirstPost {
t.Error("Post 1 not found in user posts")
}
if !foundPost2 {
if !foundSecondPost {
t.Error("Post 2 not found in user posts")
}
})
@@ -275,20 +232,20 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Consistency Post", "https://example.com/delete-consistency")
deleteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteReq.Header.Set("Authorization", "Bearer "+user.Token)
deleteReq = testutils.WithUserContext(deleteReq, middleware.UserIDKey, user.User.ID)
deleteReq = testutils.WithURLParams(deleteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(deleteRec, deleteReq)
deleteRequest := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteRequest.Header.Set("Authorization", "Bearer "+user.Token)
deleteRequest = testutils.WithUserContext(deleteRequest, middleware.UserIDKey, user.User.ID)
deleteRequest = testutils.WithURLParams(deleteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(deleteRecorder, deleteRequest)
assertStatus(t, deleteRec, http.StatusOK)
assertStatus(t, deleteRecorder, http.StatusOK)
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRecorder, getRequest)
assertStatus(t, getRec, http.StatusNotFound)
assertStatus(t, getRecorder, http.StatusNotFound)
})
t.Run("Vote_Removal_Consistency", func(t *testing.T) {
@@ -300,33 +257,33 @@ func TestIntegration_DataConsistency(t *testing.T) {
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
voteRequest := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteRequest.Header.Set("Content-Type", "application/json")
voteRequest.Header.Set("Authorization", "Bearer "+user.Token)
voteRequest = testutils.WithUserContext(voteRequest, middleware.UserIDKey, user.User.ID)
voteRequest = testutils.WithURLParams(voteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRecorder, voteRequest)
assertStatus(t, voteRec, http.StatusOK)
assertStatus(t, voteRecorder, http.StatusOK)
removeVoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteReq.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteReq = testutils.WithUserContext(removeVoteReq, middleware.UserIDKey, user.User.ID)
removeVoteReq = testutils.WithURLParams(removeVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(removeVoteRec, removeVoteReq)
removeVoteRequest := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteRequest.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteRequest = testutils.WithUserContext(removeVoteRequest, middleware.UserIDKey, user.User.ID)
removeVoteRequest = testutils.WithURLParams(removeVoteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(removeVoteRecorder, removeVoteRequest)
assertStatus(t, removeVoteRec, http.StatusOK)
assertStatus(t, removeVoteRecorder, http.StatusOK)
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
getVotesRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesRequest.Header.Set("Authorization", "Bearer "+user.Token)
getVotesRequest = testutils.WithUserContext(getVotesRequest, middleware.UserIDKey, user.User.ID)
getVotesRequest = testutils.WithURLParams(getVotesRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRecorder, getVotesRequest)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRecorder, http.StatusOK)
if votesResponse == nil {
return
}

View File

@@ -22,41 +22,39 @@ func TestIntegration_EdgeCases(t *testing.T) {
expiredToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDAwMDAwMDB9.expired"
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+expiredToken)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/auth/me", nil)
request.Header.Set("Authorization", "Bearer "+expiredToken)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, recorder, http.StatusUnauthorized)
})
t.Run("Concurrent_Vote_Operations", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "vote_user1", "vote1@example.com")
firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "vote_user1", "vote1@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user1.User.ID, "Concurrent Vote Post", "https://example.com/concurrent")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, firstUser.User.ID, "Concurrent Vote Post", "https://example.com/concurrent")
var wg sync.WaitGroup
errors := make(chan error, 10)
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range 5 {
wg.Go(func() {
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user1.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user1.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
errors <- fmt.Errorf("unexpected status: %d", rec.Code)
request := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+firstUser.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, firstUser.User.ID)
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("unexpected status: %d", recorder.Code)
}
}()
})
}
wg.Wait()
@@ -72,8 +70,8 @@ func TestIntegration_EdgeCases(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "large_user", "large@example.com")
largeContent := make([]byte, 10001)
for i := range largeContent {
largeContent[i] = 'a'
for idx := range largeContent {
largeContent[idx] = 'a'
}
postBody := map[string]string{
@@ -82,36 +80,36 @@ func TestIntegration_EdgeCases(t *testing.T) {
"content": string(largeContent),
}
body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, recorder, http.StatusBadRequest)
smallContent := make([]byte, 1000)
for i := range smallContent {
smallContent[i] = 'a'
for idx := range smallContent {
smallContent[idx] = 'a'
}
postBody2 := map[string]string{
secondPostBody := map[string]string{
"title": "Small Post",
"url": "https://example.com/small",
"content": string(smallContent),
}
body2, _ := json.Marshal(postBody2)
req2 := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body2))
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
secondBody, _ := json.Marshal(secondPostBody)
secondRequest := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(secondBody))
secondRequest.Header.Set("Content-Type", "application/json")
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec2, req2)
ctx.Router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec2, http.StatusCreated)
assertStatus(t, secondRecorder, http.StatusCreated)
})
t.Run("Malformed_JSON_Payloads", func(t *testing.T) {
@@ -127,15 +125,15 @@ func TestIntegration_EdgeCases(t *testing.T) {
}
for _, payload := range malformedPayloads {
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(payload))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(payload))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, recorder, http.StatusBadRequest)
}
})
@@ -148,38 +146,36 @@ func TestIntegration_EdgeCases(t *testing.T) {
voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
assertStatus(t, voteRec, http.StatusOK)
voteRequest := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteRequest.Header.Set("Content-Type", "application/json")
voteRequest.Header.Set("Authorization", "Bearer "+user.Token)
voteRequest = testutils.WithUserContext(voteRequest, middleware.UserIDKey, user.User.ID)
voteRequest = testutils.WithURLParams(voteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRecorder, voteRequest)
assertStatus(t, voteRecorder, http.StatusOK)
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
}()
for range 3 {
wg.Go(func() {
request := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(recorder, request)
})
}
wg.Wait()
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
getVotesRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesRequest.Header.Set("Authorization", "Bearer "+user.Token)
getVotesRequest = testutils.WithUserContext(getVotesRequest, middleware.UserIDKey, user.User.ID)
getVotesRequest = testutils.WithURLParams(getVotesRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRecorder, getVotesRequest)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRecorder, http.StatusOK)
if votesResponse != nil {
if data, ok := votesResponse["data"].(map[string]any); ok {
if votes, ok := data["votes"].([]any); ok {

View File

@@ -1,14 +1,10 @@
package integration
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/testutils"
)
@@ -19,19 +15,14 @@ func TestIntegration_EmailService(t *testing.T) {
t.Run("Registration_Email_Sent", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "email_reg_user",
"email": "email_reg@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/register", reqBody)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
token := ctx.Suite.EmailSender.VerificationToken()
if token == "" {
@@ -52,15 +43,10 @@ func TestIntegration_EmailService(t *testing.T) {
t.Fatalf("Failed to create user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"username_or_email": "email_reset_user",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
makePostRequestWithJSON(t, router, "/api/auth/forgot-password", reqBody)
token := ctx.Suite.EmailSender.PasswordResetToken()
if token == "" {
@@ -72,17 +58,9 @@ func TestIntegration_EmailService(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_del_user", "email_del@example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeDeleteRequest(t, router, "/api/auth/account", user, nil)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
token := ctx.Suite.EmailSender.DeletionToken()
if token == "" {
@@ -94,17 +72,10 @@ func TestIntegration_EmailService(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_change_user", "email_change@example.com")
reqBody := map[string]string{
reqBody := map[string]any{
"email": "newemail@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
makePutRequest(t, router, "/api/auth/email", reqBody, user, nil)
token := ctx.Suite.EmailSender.VerificationToken()
if token == "" {
@@ -115,17 +86,12 @@ func TestIntegration_EmailService(t *testing.T) {
t.Run("Email_Template_Content", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "template_user",
"email": "template@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
makePostRequestWithJSON(t, router, "/api/auth/register", reqBody)
token := ctx.Suite.EmailSender.VerificationToken()
if token == "" {

View File

@@ -1,7 +1,6 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -10,7 +9,7 @@ import (
"strings"
"testing"
"goyco/internal/middleware"
"goyco/internal/database"
"goyco/internal/testutils"
)
@@ -20,46 +19,34 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Run("Complete_Registration_To_Post_Creation_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
registerBody := map[string]string{
registerRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "journey_user",
"email": "journey@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(registerBody)
registerReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
registerReq.Header.Set("Content-Type", "application/json")
registerRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(registerRec, registerReq)
})
assertStatus(t, registerRec, http.StatusCreated)
assertStatus(t, registerRequest, http.StatusCreated)
verificationToken := ctx.Suite.EmailSender.VerificationToken()
if verificationToken == "" {
t.Fatal("Verification token not sent")
}
confirmReq := httptest.NewRequest("GET", "/api/auth/confirm?token="+url.QueryEscape(verificationToken), nil)
confirmRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(confirmRec, confirmReq)
confirmRequest := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(verificationToken))
assertStatus(t, confirmRec, http.StatusOK)
assertStatus(t, confirmRequest, http.StatusOK)
loginBody := map[string]string{
loginRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/login", map[string]any{
"username": "journey_user",
"password": "SecurePass123!",
}
loginBodyBytes, _ := json.Marshal(loginBody)
loginReq := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBuffer(loginBodyBytes))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(loginRec, loginReq)
})
loginResponse := assertJSONResponse(t, loginRec, http.StatusOK)
loginResponse := assertJSONResponse(t, loginRequest, http.StatusOK)
if loginResponse == nil {
return
}
data, ok := loginResponse["data"].(map[string]any)
data, ok := getDataFromResponse(loginResponse)
if !ok {
t.Fatal("Login response missing data")
}
@@ -67,8 +54,8 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
var token string
if accessToken, ok := data["access_token"].(string); ok && accessToken != "" {
token = accessToken
} else if tokenVal, ok := data["token"].(string); ok && tokenVal != "" {
token = tokenVal
} else if tokenValue, ok := data["token"].(string); ok && tokenValue != "" {
token = tokenValue
} else {
t.Fatal("Login response missing access_token or token")
}
@@ -90,25 +77,19 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatalf("Login response missing user.id. Data: %+v", data)
}
postBody := map[string]string{
postBodyBytes, _ := json.Marshal(map[string]any{
"title": "Journey Test Post",
"url": "https://example.com/journey",
"content": "Test content",
}
postBodyBytes, _ := json.Marshal(postBody)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBodyBytes))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, uint(userID))
postRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(postRec, postReq)
})
postRequest := makeAuthenticatedRequest(t, ctx.Router, "POST", "/api/posts", postBodyBytes, &authenticatedUser{User: &database.User{ID: uint(userID)}, Token: token}, nil)
postResponse := assertJSONResponse(t, postRec, http.StatusCreated)
postResponse := assertJSONResponse(t, postRequest, http.StatusCreated)
if postResponse == nil {
return
}
postData, ok := postResponse["data"].(map[string]any)
postData, ok := getDataFromResponse(postResponse)
if !ok {
t.Fatal("Post response missing data")
}
@@ -118,56 +99,39 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Post response missing id")
}
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil)
getPostRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getPostRec, getPostReq)
getPostRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
assertStatus(t, getPostRec, http.StatusOK)
assertStatus(t, getPostRequest, http.StatusOK)
})
t.Run("Complete_Password_Reset_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "reset_journey_user", "reset_journey@example.com")
resetBody := map[string]string{
resetRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/forgot-password", map[string]any{
"username_or_email": "reset_journey@example.com",
}
resetBodyBytes, _ := json.Marshal(resetBody)
resetReq := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(resetBodyBytes))
resetReq.Header.Set("Content-Type", "application/json")
resetRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(resetRec, resetReq)
})
assertStatus(t, resetRec, http.StatusOK)
assertStatus(t, resetRequest, http.StatusOK)
resetToken := ctx.Suite.EmailSender.GetLastPasswordResetToken()
if resetToken == "" {
t.Fatal("Password reset token not sent")
}
newPasswordBody := map[string]string{
newPasswordRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/reset-password", map[string]any{
"token": resetToken,
"new_password": "NewSecurePass123!",
}
newPasswordBodyBytes, _ := json.Marshal(newPasswordBody)
newPasswordReq := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(newPasswordBodyBytes))
newPasswordReq.Header.Set("Content-Type", "application/json")
newPasswordRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(newPasswordRec, newPasswordReq)
})
assertStatus(t, newPasswordRec, http.StatusOK)
assertStatus(t, newPasswordRequest, http.StatusOK)
loginBody := map[string]string{
loginRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/login", map[string]any{
"username": "reset_journey_user",
"password": "NewSecurePass123!",
}
loginBodyBytes, _ := json.Marshal(loginBody)
loginReq := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBuffer(loginBodyBytes))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(loginRec, loginReq)
})
assertStatus(t, loginRec, http.StatusOK)
assertStatus(t, loginRequest, http.StatusOK)
})
t.Run("Complete_Vote_And_Unvote_Journey", func(t *testing.T) {
@@ -176,40 +140,21 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Journey Post", "https://example.com/vote-journey")
voteBody := map[string]string{"type": "up"}
voteBodyBytes, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRequest, http.StatusOK)
assertStatus(t, voteRec, http.StatusOK)
getVotesRequest := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
votesResponse := assertJSONResponse(t, getVotesRequest, http.StatusOK)
if votesResponse == nil {
return
}
if data, ok := votesResponse["data"].(map[string]any); ok {
if data, ok := getDataFromResponse(votesResponse); ok {
if votes, ok := data["votes"].([]any); ok && len(votes) > 0 {
unvoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
unvoteReq.Header.Set("Authorization", "Bearer "+user.Token)
unvoteReq = testutils.WithUserContext(unvoteReq, middleware.UserIDKey, user.User.ID)
unvoteReq = testutils.WithURLParams(unvoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
unvoteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(unvoteRec, unvoteReq)
unvoteRequest := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, unvoteRec, http.StatusOK)
assertStatus(t, unvoteRequest, http.StatusOK)
}
}
})
@@ -221,31 +166,31 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
csrfToken := getCSRFToken(t, pageRouter, "/register")
reqBody := url.Values{}
reqBody.Set("username", "page_journey_user")
reqBody.Set("email", "page_journey@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("password_confirm", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("username", "page_journey_user")
requestBody.Set("email", "page_journey@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("password_confirm", "SecurePass123!")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
pageRouter.ServeHTTP(rec, req)
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
pageRouter.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
verificationToken := pageCtx.Suite.EmailSender.VerificationToken()
if verificationToken == "" {
t.Fatal("Verification token not sent")
}
confirmReq := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(verificationToken), nil)
confirmRec := httptest.NewRecorder()
pageRouter.ServeHTTP(confirmRec, confirmReq)
confirmRequest := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(verificationToken), nil)
confirmRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(confirmRecorder, confirmRequest)
assertStatusRange(t, confirmRec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, confirmRecorder, http.StatusOK, http.StatusSeeOther)
loginCSRFToken := getCSRFToken(t, pageRouter, "/login")
@@ -254,15 +199,15 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
loginBody.Set("password", "SecurePass123!")
loginBody.Set("csrf_token", loginCSRFToken)
loginReq := httptest.NewRequest("POST", "/login", strings.NewReader(loginBody.Encode()))
loginReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
loginReq.AddCookie(&http.Cookie{Name: "csrf_token", Value: loginCSRFToken})
loginRec := httptest.NewRecorder()
pageRouter.ServeHTTP(loginRec, loginReq)
loginRequest := httptest.NewRequest("POST", "/login", strings.NewReader(loginBody.Encode()))
loginRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded")
loginRequest.AddCookie(&http.Cookie{Name: "csrf_token", Value: loginCSRFToken})
loginRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(loginRecorder, loginRequest)
assertStatus(t, loginRec, http.StatusSeeOther)
assertStatus(t, loginRecorder, http.StatusSeeOther)
loginCookies := loginRec.Result().Cookies()
loginCookies := loginRecorder.Result().Cookies()
var authToken string
for _, cookie := range loginCookies {
if cookie.Name == "auth_token" {
@@ -275,37 +220,31 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Auth token not set after login")
}
homeReq := httptest.NewRequest("GET", "/", nil)
homeReq.AddCookie(&http.Cookie{Name: "auth_token", Value: authToken})
homeRec := httptest.NewRecorder()
pageRouter.ServeHTTP(homeRec, homeReq)
homeRequest := httptest.NewRequest("GET", "/", nil)
homeRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: authToken})
homeRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(homeRecorder, homeRequest)
assertStatus(t, homeRec, http.StatusOK)
assertStatus(t, homeRecorder, http.StatusOK)
})
t.Run("Complete_Post_Creation_And_Update_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "post_update_journey_user", "post_update_journey@example.com")
postBody := map[string]string{
postBodyBytes, _ := json.Marshal(map[string]any{
"title": "Original Title",
"url": "https://example.com/original",
"content": "Original content",
}
postBodyBytes, _ := json.Marshal(postBody)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBodyBytes))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+user.Token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, user.User.ID)
postRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(postRec, postReq)
})
postRequest := makeAuthenticatedRequest(t, ctx.Router, "POST", "/api/posts", postBodyBytes, user, nil)
postResponse := assertJSONResponse(t, postRec, http.StatusCreated)
postResponse := assertJSONResponse(t, postRequest, http.StatusCreated)
if postResponse == nil {
return
}
postData, ok := postResponse["data"].(map[string]any)
postData, ok := getDataFromResponse(postResponse)
if !ok {
t.Fatal("Post response missing data")
}
@@ -315,34 +254,25 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Post response missing id")
}
updateBody := map[string]string{
updateBodyBytes, _ := json.Marshal(map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
updateBodyBytes, _ := json.Marshal(updateBody)
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%.0f", postID), bytes.NewBuffer(updateBodyBytes))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%.0f", postID)})
updateRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(updateRec, updateReq)
})
updateRequest := makeAuthenticatedRequest(t, ctx.Router, "PUT", fmt.Sprintf("/api/posts/%.0f", postID), updateBodyBytes, user, map[string]string{"id": fmt.Sprintf("%.0f", postID)})
updateResponse := assertJSONResponse(t, updateRec, http.StatusOK)
updateResponse := assertJSONResponse(t, updateRequest, http.StatusOK)
if updateResponse == nil {
return
}
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil)
getPostRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getPostRec, getPostReq)
getPostRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getPostResponse := assertJSONResponse(t, getPostRec, http.StatusOK)
getPostResponse := assertJSONResponse(t, getPostRequest, http.StatusOK)
if getPostResponse == nil {
return
}
if data, ok := getPostResponse["data"].(map[string]any); ok {
if data, ok := getDataFromResponse(getPostResponse); ok {
if post, ok := data["post"].(map[string]any); ok {
if title, ok := post["title"].(string); ok && title != "Updated Title" {
t.Errorf("Post title not updated: expected 'Updated Title', got '%s'", title)

View File

@@ -2,7 +2,6 @@ package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
@@ -19,58 +18,46 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "json_error_user", "json_error@example.com")
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("invalid json{")))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("invalid json{")))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, recorder, http.StatusBadRequest)
})
t.Run("Validation_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "",
"email": "invalid-email",
"password": "weak",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Database_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "db_error_user", "db_error@example.com")
reqBody := map[string]string{
reqBody := map[string]any{
"title": "Test Post",
"url": "https://example.com/test",
"content": "Test content",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makePostRequest(t, ctx.Router, "/api/posts", reqBody, user, nil)
ctx.Router.ServeHTTP(rec, req)
if rec.Code == http.StatusInternalServerError {
assertErrorResponse(t, rec, http.StatusInternalServerError)
if request.Code == http.StatusInternalServerError {
assertErrorResponse(t, request, http.StatusInternalServerError)
} else {
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
}
})
@@ -78,34 +65,23 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "notfound_error_user", "notfound_error@example.com")
req := httptest.NewRequest("GET", "/api/posts/999999", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": "999999"})
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/posts/999999", user, map[string]string{"id": "999999"})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusNotFound)
assertErrorResponse(t, request, http.StatusNotFound)
})
t.Run("Unauthorized_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"title": "Test Post",
"url": "https://example.com/test",
"content": "Test content",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, ctx.Router, "/api/posts", reqBody)
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Forbidden_Error_Propagation", func(t *testing.T) {
@@ -115,79 +91,59 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Forbidden Post", "https://example.com/forbidden")
updateBody := map[string]string{
updateBody := map[string]any{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), updateBody, otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusForbidden)
assertErrorResponse(t, request, http.StatusForbidden)
})
t.Run("Service_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{
reqBody := map[string]any{
"username": "existing_user",
"email": "existing@example.com",
"password": "SecurePass123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
assertStatus(t, rec, http.StatusCreated)
assertStatus(t, request, http.StatusCreated)
req = httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
request = makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
assertStatusRange(t, rec, http.StatusBadRequest, http.StatusConflict)
assertErrorResponse(t, rec, rec.Code)
assertStatusRange(t, request, http.StatusBadRequest, http.StatusConflict)
assertErrorResponse(t, request, request.Code)
})
t.Run("Middleware_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer expired.invalid.token")
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer expired.invalid.token")
recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusUnauthorized)
assertErrorResponse(t, recorder, http.StatusUnauthorized)
})
t.Run("Handler_Error_Response_Format", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("GET", "/api/nonexistent", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, ctx.Router, "/api/nonexistent")
ctx.Router.ServeHTTP(rec, req)
if rec.Code == http.StatusNotFound {
if rec.Header().Get("Content-Type") == "application/json" {
assertErrorResponse(t, rec, http.StatusNotFound)
} else {
if rec.Body.Len() == 0 {
if request.Code == http.StatusNotFound {
if request.Header().Get("Content-Type") == "application/json" {
assertErrorResponse(t, request, http.StatusNotFound)
} else if request.Body.Len() == 0 {
t.Error("Expected error response body")
}
}
}
})
}

View File

@@ -11,55 +11,35 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"goyco/internal/database"
"goyco/internal/handlers"
"goyco/internal/middleware"
"goyco/internal/repositories"
"goyco/internal/services"
"goyco/internal/testutils"
"github.com/golang-jwt/jwt/v5"
)
func TestIntegration_Handlers(t *testing.T) {
suite := testutils.NewServiceSuite(t)
authService, err := services.NewAuthFacadeForTest(testutils.AppTestConfig, suite.UserRepo, suite.PostRepo, suite.DeletionRepo, suite.RefreshTokenRepo, suite.EmailSender)
if err != nil {
t.Fatalf("Failed to create auth service: %v", err)
}
voteService := services.NewVoteService(suite.VoteRepo, suite.PostRepo, suite.DB)
emailSender := suite.EmailSender
userRepo := suite.UserRepo
postRepo := suite.PostRepo
titleFetcher := suite.TitleFetcher
authHandler := handlers.NewAuthHandler(authService, userRepo)
postHandler := handlers.NewPostHandler(postRepo, titleFetcher, voteService)
voteHandler := handlers.NewVoteHandler(voteService)
userHandler := handlers.NewUserHandler(userRepo, authService)
ctx := setupTestContext(t)
authService := ctx.AuthService
emailSender := ctx.Suite.EmailSender
userRepo := ctx.Suite.UserRepo
postRepo := ctx.Suite.PostRepo
t.Run("Auth_Handler_Complete_Workflow", func(t *testing.T) {
emailSender.Reset()
registerData := map[string]string{
registerResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "handler_user",
"email": "handler@example.com",
"password": "SecurePass123!",
}
registerBody, _ := json.Marshal(registerData)
registerReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(registerBody))
registerReq.Header.Set("Content-Type", "application/json")
registerResp := httptest.NewRecorder()
authHandler.Register(registerResp, registerReq)
if registerResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", registerResp.Code)
})
if registerResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", registerResponse.Code)
}
var registerPayload map[string]any
if err := json.Unmarshal(registerResp.Body.Bytes(), &registerPayload); err != nil {
t.Fatalf("Failed to decode register response: %v", err)
}
registerPayload := assertJSONResponse(t, registerResponse, http.StatusCreated)
if success, _ := registerPayload["success"].(bool); !success {
t.Fatalf("Expected register response success, got %v", registerPayload)
}
@@ -78,11 +58,9 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to update user with mock token: %v", err)
}
confirmReq := httptest.NewRequest(http.MethodGet, "/api/auth/confirm?token="+url.QueryEscape(mockToken), nil)
confirmResp := httptest.NewRecorder()
authHandler.ConfirmEmail(confirmResp, confirmReq)
if confirmResp.Code != http.StatusOK {
t.Fatalf("Expected 200 when confirming email via handler, got %d", confirmResp.Code)
confirmResponse := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(mockToken))
if confirmResponse.Code != http.StatusOK {
t.Fatalf("Expected 200 when confirming email via handler, got %d", confirmResponse.Code)
}
loginSeed := createAuthenticatedUser(t, authService, userRepo, "auth_handler_login", "auth_handler_login@example.com")
@@ -92,92 +70,60 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Service login failed for seeded user: %v", err)
}
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+loginAuth.AccessToken)
meReq = testutils.WithUserContext(meReq, middleware.UserIDKey, loginSeed.User.ID)
meResp := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
if meResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", meResp.Code)
meResponse := makeAuthenticatedGetRequest(t, ctx.Router, "/api/auth/me", &authenticatedUser{User: loginSeed.User, Token: loginAuth.AccessToken}, nil)
if meResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", meResponse.Code)
}
})
t.Run("Auth_Handler_Security_Validation", func(t *testing.T) {
emailSender.Reset()
weakData := map[string]string{
weakResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "weak_user",
"email": "weak@example.com",
"password": "123",
}
weakBody, _ := json.Marshal(weakData)
weakReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(weakBody))
weakReq.Header.Set("Content-Type", "application/json")
weakResp := httptest.NewRecorder()
authHandler.Register(weakResp, weakReq)
if weakResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for weak password, got %d", weakResp.Code)
})
if weakResponse.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for weak password, got %d", weakResponse.Code)
}
var weakErrorResp map[string]any
if err := json.Unmarshal(weakResp.Body.Bytes(), &weakErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := weakErrorResp["success"].(bool); success {
weakErrorResponse := assertJSONResponse(t, weakResponse, http.StatusBadRequest)
if success, _ := weakErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := weakErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := weakErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
invalidData := map[string]string{
invalidResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "invalid_user",
"email": "not-an-email",
"password": "SecurePass123!",
}
invalidBody, _ := json.Marshal(invalidData)
invalidReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(invalidBody))
invalidReq.Header.Set("Content-Type", "application/json")
invalidResp := httptest.NewRecorder()
authHandler.Register(invalidResp, invalidReq)
if invalidResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid email, got %d", invalidResp.Code)
})
if invalidResponse.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid email, got %d", invalidResponse.Code)
}
var invalidEmailErrorResp map[string]any
if err := json.Unmarshal(invalidResp.Body.Bytes(), &invalidEmailErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := invalidEmailErrorResp["success"].(bool); success {
invalidEmailErrorResponse := assertJSONResponse(t, invalidResponse, http.StatusBadRequest)
if success, _ := invalidEmailErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := invalidEmailErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := invalidEmailErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
incompleteData := map[string]string{
incompleteResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "incomplete_user",
}
incompleteBody, _ := json.Marshal(incompleteData)
incompleteReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(incompleteBody))
incompleteReq.Header.Set("Content-Type", "application/json")
incompleteResp := httptest.NewRecorder()
authHandler.Register(incompleteResp, incompleteReq)
if incompleteResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for missing fields, got %d", incompleteResp.Code)
})
if incompleteResponse.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for missing fields, got %d", incompleteResponse.Code)
}
var incompleteErrorResp map[string]any
if err := json.Unmarshal(incompleteResp.Body.Bytes(), &incompleteErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := incompleteErrorResp["success"].(bool); success {
incompleteErrorResponse := assertJSONResponse(t, incompleteResponse, http.StatusBadRequest)
if success, _ := incompleteErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := incompleteErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := incompleteErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
})
@@ -186,28 +132,17 @@ func TestIntegration_Handlers(t *testing.T) {
emailSender.Reset()
user := createAuthenticatedUser(t, authService, userRepo, "post_user", "post@example.com")
postData := map[string]string{
postResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "Handler Test Post",
"url": "https://example.com/handler-test",
"content": "This is a handler test post",
}
postBody, _ := json.Marshal(postData)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBody))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+user.Token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, user.User.ID)
postResp := httptest.NewRecorder()
postHandler.CreatePost(postResp, postReq)
if postResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", postResp.Code)
}, user, nil)
if postResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", postResponse.Code)
}
var postResult map[string]any
if err := json.Unmarshal(postResp.Body.Bytes(), &postResult); err != nil {
t.Fatalf("Failed to decode post response: %v", err)
}
postDetails, ok := postResult["data"].(map[string]any)
postResult := assertJSONResponse(t, postResponse, http.StatusCreated)
postDetails, ok := getDataFromResponse(postResult)
if !ok {
t.Fatalf("Expected data object in post response, got %v", postResult)
}
@@ -216,87 +151,49 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatal("Expected post ID in response")
}
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", int(postID)), nil)
getReq = testutils.WithURLParams(getReq, map[string]string{"id": fmt.Sprintf("%d", int(postID))})
getResp := httptest.NewRecorder()
postHandler.GetPost(getResp, getReq)
if getResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getResp.Code)
getResponse := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", int(postID)))
if getResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getResponse.Code)
}
postsReq := httptest.NewRequest("GET", "/api/posts", nil)
postsResp := httptest.NewRecorder()
postHandler.GetPosts(postsResp, postsReq)
if postsResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", postsResp.Code)
postsResponse := makeGetRequest(t, ctx.Router, "/api/posts")
if postsResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", postsResponse.Code)
}
searchReq := httptest.NewRequest("GET", "/api/posts/search?q=handler", nil)
searchResp := httptest.NewRecorder()
postHandler.SearchPosts(searchResp, searchReq)
if searchResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", searchResp.Code)
searchResponse := makeGetRequest(t, ctx.Router, "/api/posts/search?q=handler")
if searchResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", searchResponse.Code)
}
})
t.Run("Post_Handler_Security_Validation", func(t *testing.T) {
emailSender.Reset()
postData := map[string]string{
postResponse := makePostRequestWithJSON(t, ctx.Router, "/api/posts", map[string]any{
"title": "Unauthorized Post",
"url": "https://example.com/unauthorized",
"content": "This should fail",
}
postBody, _ := json.Marshal(postData)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBody))
postReq.Header.Set("Content-Type", "application/json")
postResp := httptest.NewRecorder()
postHandler.CreatePost(postResp, postReq)
if postResp.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401 for unauthenticated post creation, got %d", postResp.Code)
}
var authErrorResp map[string]any
if err := json.Unmarshal(postResp.Body.Bytes(), &authErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := authErrorResp["success"].(bool); success {
})
authErrorResponse := assertJSONResponse(t, postResponse, http.StatusUnauthorized)
if success, _ := authErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := authErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := authErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain authentication error message")
}
user := createAuthenticatedUser(t, authService, userRepo, "security_user", "security@example.com")
invalidData := map[string]string{
invalidResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "",
"url": "not-a-url",
"content": "Invalid post",
}
invalidBody, _ := json.Marshal(invalidData)
invalidReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(invalidBody))
invalidReq.Header.Set("Content-Type", "application/json")
invalidReq.Header.Set("Authorization", "Bearer "+user.Token)
invalidReq = testutils.WithUserContext(invalidReq, middleware.UserIDKey, user.User.ID)
invalidResp := httptest.NewRecorder()
postHandler.CreatePost(invalidResp, invalidReq)
if invalidResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid post data, got %d", invalidResp.Code)
}
var postValidationErrorResp map[string]any
if err := json.Unmarshal(invalidResp.Body.Bytes(), &postValidationErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := postValidationErrorResp["success"].(bool); success {
}, user, nil)
postValidationErrorResponse := assertJSONResponse(t, invalidResponse, http.StatusBadRequest)
if success, _ := postValidationErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := postValidationErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := postValidationErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message")
}
})
@@ -306,161 +203,100 @@ func TestIntegration_Handlers(t *testing.T) {
user := createAuthenticatedUser(t, authService, userRepo, "vote_handler_user", "vote_handler@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user.User.ID, "Vote Handler Test Post", "https://example.com/vote-handler")
voteData := map[string]string{
"type": "up",
}
voteBody, _ := json.Marshal(voteData)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBody))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteResp := httptest.NewRecorder()
voteResponse := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteResponse, http.StatusOK)
voteHandler.CastVote(voteResp, voteReq)
if voteResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", voteResp.Code)
}
getVoteResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, getVoteResponse, http.StatusOK)
getVoteReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
getVoteReq.Header.Set("Authorization", "Bearer "+user.Token)
getVoteReq = testutils.WithUserContext(getVoteReq, middleware.UserIDKey, user.User.ID)
getVoteReq = testutils.WithURLParams(getVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVoteResp := httptest.NewRecorder()
getPostVotesResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, getPostVotesResponse, http.StatusOK)
voteHandler.GetUserVote(getVoteResp, getVoteReq)
if getVoteResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getVoteResp.Code)
}
getPostVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getPostVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getPostVotesReq = testutils.WithUserContext(getPostVotesReq, middleware.UserIDKey, user.User.ID)
getPostVotesReq = testutils.WithURLParams(getPostVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostVotesResp := httptest.NewRecorder()
voteHandler.GetPostVotes(getPostVotesResp, getPostVotesReq)
if getPostVotesResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getPostVotesResp.Code)
}
removeVoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteReq.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteReq = testutils.WithUserContext(removeVoteReq, middleware.UserIDKey, user.User.ID)
removeVoteReq = testutils.WithURLParams(removeVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteResp := httptest.NewRecorder()
voteHandler.RemoveVote(removeVoteResp, removeVoteReq)
if removeVoteResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", removeVoteResp.Code)
}
removeVoteResponse := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, removeVoteResponse, http.StatusOK)
})
t.Run("User_Handler_Complete_Workflow", func(t *testing.T) {
emailSender.Reset()
user := createAuthenticatedUser(t, authService, userRepo, "user_handler_user", "user_handler@example.com")
usersReq := httptest.NewRequest("GET", "/api/users", nil)
usersReq.Header.Set("Authorization", "Bearer "+user.Token)
usersReq = testutils.WithUserContext(usersReq, middleware.UserIDKey, user.User.ID)
usersResp := httptest.NewRecorder()
usersResponse := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
assertStatus(t, usersResponse, http.StatusOK)
userHandler.GetUsers(usersResp, usersReq)
if usersResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", usersResp.Code)
}
getUserResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
assertStatus(t, getUserResponse, http.StatusOK)
getUserReq := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user.User.ID), nil)
getUserReq.Header.Set("Authorization", "Bearer "+user.Token)
getUserReq = testutils.WithUserContext(getUserReq, middleware.UserIDKey, user.User.ID)
getUserReq = testutils.WithURLParams(getUserReq, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserResp := httptest.NewRecorder()
userHandler.GetUser(getUserResp, getUserReq)
if getUserResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getUserResp.Code)
}
getUserPostsReq := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil)
getUserPostsReq.Header.Set("Authorization", "Bearer "+user.Token)
getUserPostsReq = testutils.WithUserContext(getUserPostsReq, middleware.UserIDKey, user.User.ID)
getUserPostsReq = testutils.WithURLParams(getUserPostsReq, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserPostsResp := httptest.NewRecorder()
userHandler.GetUserPosts(getUserPostsResp, getUserPostsReq)
if getUserPostsResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getUserPostsResp.Code)
}
getUserPostsResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
assertStatus(t, getUserPostsResponse, http.StatusOK)
})
t.Run("Error_Handling_Invalid_Requests", func(t *testing.T) {
invalidJSONReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer([]byte("invalid json")))
invalidJSONReq.Header.Set("Content-Type", "application/json")
invalidJSONResp := httptest.NewRecorder()
middleware.StopAllRateLimiters()
ctx.Suite.EmailSender.Reset()
authHandler.Register(invalidJSONResp, invalidJSONReq)
if invalidJSONResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid JSON, got %d", invalidJSONResp.Code)
}
var jsonErrorResp map[string]any
if err := json.Unmarshal(invalidJSONResp.Body.Bytes(), &jsonErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := jsonErrorResp["success"].(bool); success {
invalidJSONResponse := makeRequest(t, ctx.Router, "POST", "/api/auth/register", []byte("invalid json"), map[string]string{"Content-Type": "application/json"})
jsonErrorResponse := assertJSONResponse(t, invalidJSONResponse, http.StatusBadRequest)
if success, _ := jsonErrorResponse["success"].(bool); success {
t.Error("Expected error response to have success=false")
}
if errorMsg, ok := jsonErrorResp["error"].(string); !ok || errorMsg == "" {
if errorMsg, ok := jsonErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain JSON parsing error message")
}
missingCTData := map[string]string{
"username": "missing_ct_user",
"email": "missing_ct@example.com",
"username": uniqueTestUsername(t, "missing_ct"),
"email": uniqueTestEmail(t, "missing_ct"),
"password": "SecurePass123!",
}
missingCTBody, _ := json.Marshal(missingCTData)
missingCTReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(missingCTBody))
missingCTResp := httptest.NewRecorder()
missingCTRequest := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(missingCTBody))
missingCTResponse := httptest.NewRecorder()
authHandler.Register(missingCTResp, missingCTReq)
if missingCTResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", missingCTResp.Code)
ctx.Router.ServeHTTP(missingCTResponse, missingCTRequest)
if missingCTResponse.Code == http.StatusTooManyRequests {
var rateLimitResponse map[string]any
if err := json.Unmarshal(missingCTResponse.Body.Bytes(), &rateLimitResponse); err != nil {
t.Errorf("Rate limited but response is not valid JSON: %v", err)
} else {
t.Logf("Rate limit hit (expected in full test suite run), but request was processed correctly (not rejected as invalid JSON)")
}
} else if missingCTResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", missingCTResponse.Code)
}
invalidEndpointReq := httptest.NewRequest("GET", "/api/invalid/endpoint", nil)
invalidEndpointResp := httptest.NewRecorder()
invalidEndpointRequest := httptest.NewRequest("GET", "/api/invalid/endpoint", nil)
invalidEndpointResponse := httptest.NewRecorder()
authHandler.Me(invalidEndpointResp, invalidEndpointReq)
if invalidEndpointResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(invalidEndpointResponse, invalidEndpointRequest)
if invalidEndpointResponse.Code == http.StatusOK {
t.Error("Expected error for invalid endpoint")
}
})
t.Run("Security_Authentication_Bypass", func(t *testing.T) {
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meResp := httptest.NewRecorder()
meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
if meResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(meResponse, meRequest)
if meResponse.Code == http.StatusOK {
t.Error("Expected error for unauthenticated request")
}
invalidTokenReq := httptest.NewRequest("GET", "/api/auth/me", nil)
invalidTokenReq.Header.Set("Authorization", "Bearer invalid-token")
invalidTokenResp := httptest.NewRecorder()
invalidTokenRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
invalidTokenRequest.Header.Set("Authorization", "Bearer invalid-token")
invalidTokenResponse := httptest.NewRecorder()
authHandler.Me(invalidTokenResp, invalidTokenReq)
if invalidTokenResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(invalidTokenResponse, invalidTokenRequest)
if invalidTokenResponse.Code == http.StatusOK {
t.Error("Expected error for invalid token")
}
malformedTokenReq := httptest.NewRequest("GET", "/api/auth/me", nil)
malformedTokenReq.Header.Set("Authorization", "InvalidFormat token")
malformedTokenResp := httptest.NewRecorder()
malformedTokenRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
malformedTokenRequest.Header.Set("Authorization", "InvalidFormat token")
malformedTokenResponse := httptest.NewRecorder()
authHandler.Me(malformedTokenResp, malformedTokenReq)
if malformedTokenResp.Code == http.StatusOK {
ctx.Router.ServeHTTP(malformedTokenResponse, malformedTokenRequest)
if malformedTokenResponse.Code == http.StatusOK {
t.Error("Expected error for malformed token")
}
})
@@ -468,32 +304,21 @@ func TestIntegration_Handlers(t *testing.T) {
t.Run("Security_Input_Sanitization", func(t *testing.T) {
user := createAuthenticatedUser(t, authService, userRepo, "xss_user", "xss@example.com")
xssData := map[string]string{
xssResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "<script>alert('xss')</script>",
"url": "https://example.com/xss",
"content": "XSS test content",
}
xssBody, _ := json.Marshal(xssData)
xssReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(xssBody))
xssReq.Header.Set("Content-Type", "application/json")
xssReq.Header.Set("Authorization", "Bearer "+user.Token)
xssReq = testutils.WithUserContext(xssReq, middleware.UserIDKey, user.User.ID)
xssResp := httptest.NewRecorder()
postHandler.CreatePost(xssResp, xssReq)
if xssResp.Code != http.StatusCreated {
t.Errorf("Expected status 201 for XSS sanitization, got %d", xssResp.Code)
"content": "<script>alert('xss')</script>",
}, user, nil)
if xssResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201 for XSS sanitization, got %d", xssResponse.Code)
}
var xssResult map[string]any
if err := json.Unmarshal(xssResp.Body.Bytes(), &xssResult); err != nil {
t.Fatalf("Failed to decode XSS response: %v", err)
}
xssResult := assertJSONResponse(t, xssResponse, http.StatusCreated)
if success, _ := xssResult["success"].(bool); !success {
t.Error("Expected XSS response to have success=true")
}
data, ok := xssResult["data"].(map[string]any)
data, ok := getDataFromResponse(xssResult)
if !ok {
t.Fatalf("Expected data object in XSS response, got %T", xssResult["data"])
}
@@ -522,32 +347,21 @@ func TestIntegration_Handlers(t *testing.T) {
t.Errorf("Expected script tags to be HTML-escaped in content, got: %s", content)
}
sqlData := map[string]string{
sqlResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "'; DROP TABLE posts; --",
"url": "https://example.com/sql",
"content": "SQL injection test",
}
sqlBody, _ := json.Marshal(sqlData)
sqlReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(sqlBody))
sqlReq.Header.Set("Content-Type", "application/json")
sqlReq.Header.Set("Authorization", "Bearer "+user.Token)
sqlReq = testutils.WithUserContext(sqlReq, middleware.UserIDKey, user.User.ID)
sqlResp := httptest.NewRecorder()
postHandler.CreatePost(sqlResp, sqlReq)
if sqlResp.Code != http.StatusCreated {
t.Errorf("Expected status 201 for SQL injection sanitization, got %d", sqlResp.Code)
}, user, nil)
if sqlResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201 for SQL injection sanitization, got %d", sqlResponse.Code)
}
var sqlResult map[string]any
if err := json.Unmarshal(sqlResp.Body.Bytes(), &sqlResult); err != nil {
t.Fatalf("Failed to decode SQL response: %v", err)
}
sqlResult := assertJSONResponse(t, sqlResponse, http.StatusCreated)
if success, _ := sqlResult["success"].(bool); !success {
t.Error("Expected SQL response to have success=true")
}
sqlResponseData, ok := sqlResult["data"].(map[string]any)
sqlResponseData, ok := getDataFromResponse(sqlResult)
if !ok {
t.Fatalf("Expected data object in SQL response, got %T", sqlResult["data"])
}
@@ -579,63 +393,31 @@ func TestIntegration_Handlers(t *testing.T) {
t.Run("Authorization_User_Access_Control", func(t *testing.T) {
emailSender.Reset()
user1 := createAuthenticatedUser(t, authService, userRepo, "auth_user1", "auth1@example.com")
user2 := createAuthenticatedUser(t, authService, userRepo, "auth_user2", "auth2@example.com")
firstUser := createAuthenticatedUser(t, authService, userRepo, "auth_user1", "auth1@example.com")
secondUser := createAuthenticatedUser(t, authService, userRepo, "auth_user2", "auth2@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user1.User.ID, "Private Post", "https://example.com/private")
post := testutils.CreatePostWithRepo(t, postRepo, firstUser.User.ID, "Private Post", "https://example.com/private")
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getPostReq.Header.Set("Authorization", "Bearer "+user2.Token)
getPostReq = testutils.WithUserContext(getPostReq, middleware.UserIDKey, user2.User.ID)
getPostReq = testutils.WithURLParams(getPostReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostResp := httptest.NewRecorder()
getPostResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
testutils.AssertHTTPStatus(t, getPostResponse, http.StatusOK)
postHandler.GetPost(getPostResp, getPostReq)
testutils.AssertHTTPStatus(t, getPostResp, http.StatusOK)
updateResponse := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{"title": "Updated Title"}, secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
testutils.AssertHTTPStatus(t, updateResponse, http.StatusForbidden)
updateData := map[string]string{
"title": "Updated Title",
}
updateBody, _ := json.Marshal(updateData)
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(updateBody))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user2.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user2.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateResp := httptest.NewRecorder()
postHandler.UpdatePost(updateResp, updateReq)
testutils.AssertHTTPStatus(t, updateResp, http.StatusForbidden)
deleteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteReq.Header.Set("Authorization", "Bearer "+user2.Token)
deleteReq = testutils.WithUserContext(deleteReq, middleware.UserIDKey, user2.User.ID)
deleteReq = testutils.WithURLParams(deleteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteResp := httptest.NewRecorder()
postHandler.DeletePost(deleteResp, deleteReq)
testutils.AssertHTTPStatus(t, deleteResp, http.StatusForbidden)
deleteResponse := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
testutils.AssertHTTPStatus(t, deleteResponse, http.StatusForbidden)
})
t.Run("Authorization_Vote_Access_Control", func(t *testing.T) {
emailSender.Reset()
user1 := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user1", "vote_auth1@example.com")
user2 := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user2", "vote_auth2@example.com")
firstUser := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user1", "vote_auth1@example.com")
secondUser := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user2", "vote_auth2@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user1.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
post := testutils.CreatePostWithRepo(t, postRepo, firstUser.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
voteData := map[string]string{"type": "up"}
voteBody, _ := json.Marshal(voteData)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBody))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user2.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user2.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteResp := httptest.NewRecorder()
voteHandler.CastVote(voteResp, voteReq)
if voteResp.Code != http.StatusOK {
t.Errorf("Users should be able to vote on any post, got %d", voteResp.Code)
voteResponse := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
if voteResponse.Code != http.StatusOK {
t.Errorf("Users should be able to vote on any post, got %d", voteResponse.Code)
}
})
@@ -664,12 +446,8 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to generate expired token: %v", err)
}
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+expiredToken)
meResp := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized)
meResponse := makeRequest(t, ctx.Router, "GET", "/api/auth/me", nil, map[string]string{"Authorization": "Bearer " + expiredToken})
testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
})
t.Run("Authorization_Token_Tampering", func(t *testing.T) {
@@ -678,12 +456,12 @@ func TestIntegration_Handlers(t *testing.T) {
tamperedToken := user.Token[:len(user.Token)-5] + "XXXXX"
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+tamperedToken)
meResp := httptest.NewRecorder()
meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meRequest.Header.Set("Authorization", "Bearer "+tamperedToken)
meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized)
ctx.Router.ServeHTTP(meResponse, meRequest)
testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
})
t.Run("Authorization_Session_Version_Mismatch", func(t *testing.T) {
@@ -711,12 +489,12 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to generate invalid token: %v", err)
}
meReq := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+invalidToken)
meResp := httptest.NewRecorder()
meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meRequest.Header.Set("Authorization", "Bearer "+invalidToken)
meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized)
ctx.Router.ServeHTTP(meResponse, meRequest)
testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
})
}
@@ -748,11 +526,9 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
}
voteService := services.NewVoteService(voteRepo, postRepo, db)
apiHandler := handlers.NewAPIHandlerWithMonitoring(testutils.AppTestConfig, postRepo, userRepo, voteService, db, monitor)
t.Run("Health endpoint includes database monitoring", func(t *testing.T) {
user := &database.User{
Username: "monitoring_user",
Email: "monitoring@example.com",
@@ -765,7 +541,6 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
recorder := httptest.NewRecorder()
apiHandler.GetHealth(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var response map[string]any
@@ -777,7 +552,7 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
t.Error("Expected success to be true")
}
data, ok := response["data"].(map[string]any)
data, ok := getDataFromResponse(response)
if !ok {
t.Fatal("Expected data to be a map")
}
@@ -813,7 +588,6 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
recorder := httptest.NewRecorder()
apiHandler.GetMetrics(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var response map[string]any
@@ -825,7 +599,7 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
t.Error("Expected success to be true")
}
data, ok := response["data"].(map[string]any)
data, ok := getDataFromResponse(response)
if !ok {
t.Fatal("Expected data to be a map")
}

View File

@@ -1,6 +1,7 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -282,14 +283,14 @@ func setupPageHandlerTestContext(t *testing.T) *testContext {
func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*http.Cookie) string {
t.Helper()
req := httptest.NewRequest("GET", path, nil)
request := httptest.NewRequest("GET", path, nil)
for _, cookie := range cookies {
req.AddCookie(cookie)
request.AddCookie(cookie)
}
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
cookieList := rec.Result().Cookies()
cookieList := recorder.Result().Cookies()
for _, cookie := range cookieList {
if cookie.Name == "csrf_token" {
return cookie.Value
@@ -299,32 +300,32 @@ func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*ht
return ""
}
func assertJSONResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) map[string]any {
func assertJSONResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) map[string]any {
t.Helper()
if rec.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String())
if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
return nil
}
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v. Body: %s", err, rec.Body.String())
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v. Body: %s", err, recorder.Body.String())
return nil
}
return response
}
func assertErrorResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) {
func assertErrorResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) {
t.Helper()
if rec.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String())
if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
return
}
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode error response: %v. Body: %s", err, rec.Body.String())
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode error response: %v. Body: %s", err, recorder.Body.String())
return
}
@@ -335,23 +336,23 @@ func assertErrorResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedS
}
}
func assertStatus(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) {
func assertStatus(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) {
t.Helper()
if rec.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String())
if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
}
}
func assertStatusRange(t *testing.T, rec *httptest.ResponseRecorder, minStatus, maxStatus int) {
func assertStatusRange(t *testing.T, recorder *httptest.ResponseRecorder, minStatus, maxStatus int) {
t.Helper()
if rec.Code < minStatus || rec.Code > maxStatus {
t.Errorf("Expected status between %d and %d, got %d. Body: %s", minStatus, maxStatus, rec.Code, rec.Body.String())
if recorder.Code < minStatus || recorder.Code > maxStatus {
t.Errorf("Expected status between %d and %d, got %d. Body: %s", minStatus, maxStatus, recorder.Code, recorder.Body.String())
}
}
func assertCookie(t *testing.T, rec *httptest.ResponseRecorder, name, expectedValue string) {
func assertCookie(t *testing.T, recorder *httptest.ResponseRecorder, name, expectedValue string) {
t.Helper()
cookies := rec.Result().Cookies()
cookies := recorder.Result().Cookies()
for _, cookie := range cookies {
if cookie.Name == name {
if expectedValue != "" && cookie.Value != expectedValue {
@@ -363,9 +364,9 @@ func assertCookie(t *testing.T, rec *httptest.ResponseRecorder, name, expectedVa
t.Errorf("Expected cookie %s not found", name)
}
func assertCookieCleared(t *testing.T, rec *httptest.ResponseRecorder, name string) {
func assertCookieCleared(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
t.Helper()
cookies := rec.Result().Cookies()
cookies := recorder.Result().Cookies()
for _, cookie := range cookies {
if cookie.Name == name {
if cookie.Value != "" {
@@ -376,21 +377,17 @@ func assertCookieCleared(t *testing.T, rec *httptest.ResponseRecorder, name stri
}
}
func assertHeader(t *testing.T, rec *httptest.ResponseRecorder, name, expectedValue string) {
func assertHeader(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
t.Helper()
actualValue := rec.Header().Get(name)
if expectedValue == "" {
actualValue := recorder.Header().Get(name)
if actualValue == "" {
t.Errorf("Expected header %s to be present", name)
}
} else if actualValue != expectedValue {
t.Errorf("Expected header %s=%s, got %s", name, expectedValue, actualValue)
}
}
func assertHeaderContains(t *testing.T, rec *httptest.ResponseRecorder, name, substring string) {
func assertHeaderContains(t *testing.T, recorder *httptest.ResponseRecorder, name, substring string) {
t.Helper()
actualValue := rec.Header().Get(name)
actualValue := recorder.Header().Get(name)
if !strings.Contains(actualValue, substring) {
t.Errorf("Expected header %s to contain %s, got %s", name, substring, actualValue)
}
@@ -450,3 +447,83 @@ func createUserWithCleanup(t *testing.T, ctx *testContext, username, email strin
})
return user
}
func makeRequest(t *testing.T, router http.Handler, method, path string, body []byte, headers map[string]string) *httptest.ResponseRecorder {
t.Helper()
var requestBody *bytes.Buffer
if body != nil {
requestBody = bytes.NewBuffer(body)
} else {
requestBody = bytes.NewBuffer(nil)
}
request := httptest.NewRequest(method, path, requestBody)
for key, value := range headers {
request.Header.Set(key, value)
}
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
return recorder
}
func makeAuthenticatedRequest(t *testing.T, router http.Handler, method, path string, body []byte, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
var requestBody *bytes.Buffer
if body != nil {
requestBody = bytes.NewBuffer(body)
} else {
requestBody = bytes.NewBuffer(nil)
}
request := httptest.NewRequest(method, path, requestBody)
request.Header.Set("Authorization", "Bearer "+user.Token)
if body != nil {
request.Header.Set("Content-Type", "application/json")
}
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
if urlParams != nil {
request = testutils.WithURLParams(request, urlParams)
}
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
return recorder
}
func makeGetRequest(t *testing.T, router http.Handler, path string) *httptest.ResponseRecorder {
t.Helper()
return makeRequest(t, router, "GET", path, nil, nil)
}
func makeAuthenticatedGetRequest(t *testing.T, router http.Handler, path string, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
return makeAuthenticatedRequest(t, router, "GET", path, nil, user, urlParams)
}
func makePostRequest(t *testing.T, router http.Handler, path string, body map[string]any, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeAuthenticatedRequest(t, router, "POST", path, bodyBytes, user, urlParams)
}
func makePutRequest(t *testing.T, router http.Handler, path string, body map[string]any, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeAuthenticatedRequest(t, router, "PUT", path, bodyBytes, user, urlParams)
}
func makeDeleteRequest(t *testing.T, router http.Handler, path string, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
return makeAuthenticatedRequest(t, router, "DELETE", path, nil, user, urlParams)
}
func makePostRequestWithJSON(t *testing.T, router http.Handler, path string, body map[string]any) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeRequest(t, router, "POST", path, bodyBytes, map[string]string{"Content-Type": "application/json"})
}
func getDataFromResponse(response map[string]any) (map[string]any, bool) {
if response == nil {
return nil, false
}
data, ok := response["data"].(map[string]any)
return data, ok
}

View File

@@ -22,26 +22,26 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, authService, ctx.Suite.UserRepo, "settings_email_user", "settings_email@example.com")
getReq := httptest.NewRequest("GET", "/settings", nil)
getReq.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", "/settings", nil)
getRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
getRecorder := httptest.NewRecorder()
router.ServeHTTP(getRecorder, getRequest)
csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("email", "newemail@example.com")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("email", "newemail@example.com")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/email", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/settings/email", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Settings_Username_Update_Form", func(t *testing.T) {
@@ -51,19 +51,19 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("username", "new_username")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("username", "new_username")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/username", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/settings/username", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Settings_Password_Update_Form", func(t *testing.T) {
@@ -74,20 +74,20 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("current_password", "SecurePass123!")
reqBody.Set("new_password", "NewSecurePass123!")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("current_password", "SecurePass123!")
requestBody.Set("new_password", "NewSecurePass123!")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/password", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/settings/password", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Logout_Page_Handler", func(t *testing.T) {
@@ -98,19 +98,19 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{}
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/logout", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/logout", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusSeeOther)
assertCookieCleared(t, rec, "auth_token")
assertStatus(t, recorder, http.StatusSeeOther)
assertCookieCleared(t, recorder, "auth_token")
})
t.Run("Resend_Verification_Page_Handler", func(t *testing.T) {
@@ -120,18 +120,18 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/resend-verification")
reqBody := url.Values{}
reqBody.Set("email", "resend_page@example.com")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("email", "resend_page@example.com")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/resend-verification", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/resend-verification", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Post_Vote_Page_Handler", func(t *testing.T) {
@@ -142,26 +142,26 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
post := testutils.CreatePostWithRepo(t, freshCtx.Suite.PostRepo, user.User.ID, "Vote Page Test", "https://example.com/vote-page")
getReq := httptest.NewRequest("GET", fmt.Sprintf("/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", fmt.Sprintf("/posts/%d", post.ID), nil)
getRecorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(getRecorder, getRequest)
csrfToken := getCSRFToken(t, freshCtx.Router, fmt.Sprintf("/posts/%d", post.ID))
reqBody := url.Values{}
reqBody.Set("action", "up")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("action", "up")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", fmt.Sprintf("/posts/%d/vote", post.ID), strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", fmt.Sprintf("/posts/%d/vote", post.ID), strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("Login_Page_Handler_Workflow", func(t *testing.T) {
@@ -172,20 +172,20 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/login")
reqBody := url.Values{}
reqBody.Set("username", "login_page_user")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken)
requestBody := url.Values{}
requestBody.Set("username", "login_page_user")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/login", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/login", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req)
freshCtx.Router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusSeeOther)
assertCookie(t, rec, "auth_token", "")
assertStatus(t, recorder, http.StatusSeeOther)
assertCookie(t, recorder, "auth_token", "")
})
t.Run("Email_Confirmation_Page_Handler", func(t *testing.T) {
@@ -198,11 +198,11 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
token = "test-token"
}
req := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(token), nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(token), nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
}

View File

@@ -16,63 +16,62 @@ func TestIntegration_PageHandler(t *testing.T) {
router := ctx.Router
t.Run("Home_Page_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
if !strings.Contains(rec.Body.String(), "<html") {
if !strings.Contains(recorder.Body.String(), "<html") {
t.Error("Expected HTML content")
}
})
t.Run("Login_Form_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/login", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/login", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
body := rec.Body.String()
body := recorder.Body.String()
if !strings.Contains(body, "login") && !strings.Contains(body, "Login") {
t.Error("Expected login form content")
}
})
t.Run("Register_Form_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/register", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, recorder, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
body := rec.Body.String()
body := recorder.Body.String()
if !strings.Contains(body, "register") && !strings.Contains(body, "Register") {
t.Error("Expected register form content")
}
})
t.Run("PageHandler_With_CSRF_Token", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/register", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertCookie(t, rec, "csrf_token", "")
assertCookie(t, recorder, "csrf_token", "")
})
t.Run("PageHandler_Form_Submission", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
getReq := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
getRequest := httptest.NewRequest("GET", "/register", nil)
getRecorder := httptest.NewRecorder()
router.ServeHTTP(getRecorder, getRequest)
cookies := getRec.Result().Cookies()
cookies := getRecorder.Result().Cookies()
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "csrf_token" {
@@ -85,33 +84,33 @@ func TestIntegration_PageHandler(t *testing.T) {
t.Fatal("Expected CSRF cookie")
}
reqBody := url.Values{}
reqBody.Set("username", "page_form_user")
reqBody.Set("email", "page_form@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfCookie.Value)
requestBody := url.Values{}
requestBody.Set("username", "page_form_user")
requestBody.Set("email", "page_form@example.com")
requestBody.Set("password", "SecurePass123!")
requestBody.Set("csrf_token", csrfCookie.Value)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
})
t.Run("PageHandler_Authenticated_Access", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "page_auth_user", "page_auth@example.com")
req := httptest.NewRequest("GET", "/settings", nil)
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/settings", nil)
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
})
t.Run("PageHandler_Post_Display", func(t *testing.T) {
@@ -120,34 +119,34 @@ func TestIntegration_PageHandler(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Page Test Post", "https://example.com/page-test")
req := httptest.NewRequest("GET", "/posts/"+fmt.Sprintf("%d", post.ID), nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/posts/"+fmt.Sprintf("%d", post.ID), nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
body := rec.Body.String()
body := recorder.Body.String()
if !strings.Contains(body, "Page Test Post") {
t.Error("Expected post title in page")
}
})
t.Run("PageHandler_Search_Page", func(t *testing.T) {
req := httptest.NewRequest("GET", "/search?q=test", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/search?q=test", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
})
t.Run("PageHandler_Error_Handling", func(t *testing.T) {
req := httptest.NewRequest("GET", "/nonexistent", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/nonexistent", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusNotFound)
assertStatus(t, recorder, http.StatusNotFound)
})
}

View File

@@ -1,8 +1,6 @@
package integration
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
@@ -32,17 +30,12 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to create user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"username_or_email": "reset_user",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/forgot-password", reqBody)
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success=true")
@@ -77,18 +70,13 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatal("Expected password reset token")
}
reqBody := map[string]string{
reqBody := map[string]any{
"token": resetToken,
"new_password": "NewPassword123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
loginResult, err := ctx.AuthService.Login("reset_complete_user", "NewPassword123!")
if err != nil {
@@ -120,14 +108,14 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
reqBody := url.Values{}
reqBody.Set("username_or_email", "page_reset_user")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/forgot-password", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/forgot-password", strings.NewReader(reqBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
recorder := httptest.NewRecorder()
pageRouter.ServeHTTP(rec, req)
pageRouter.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther)
assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
resetToken := pageCtx.Suite.EmailSender.PasswordResetToken()
if resetToken == "" {
@@ -166,33 +154,23 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to update user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"token": resetToken,
"new_password": "NewPassword123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("PasswordReset_InvalidToken", func(t *testing.T) {
reqBody := map[string]string{
reqBody := map[string]any{
"token": "invalid-token",
"new_password": "NewPassword123!",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("PasswordReset_WeakPassword", func(t *testing.T) {
@@ -214,18 +192,13 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
resetToken := ctx.Suite.EmailSender.PasswordResetToken()
reqBody := map[string]string{
reqBody := map[string]any{
"token": resetToken,
"new_password": "123",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusBadRequest)
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("PasswordReset_EmailIntegration", func(t *testing.T) {
@@ -243,17 +216,12 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to create user: %v", err)
}
reqBody := map[string]string{
reqBody := map[string]any{
"username_or_email": "email_reset@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, freshCtx.Router, "/api/auth/forgot-password", reqBody)
freshCtx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
resetToken := freshCtx.Suite.EmailSender.PasswordResetToken()
if resetToken == "" {

View File

@@ -51,24 +51,24 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
assertHeader(t, rec, "Retry-After", "")
assertHeader(t, recorder, "Retry-After")
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil {
if err := json.NewDecoder(recorder.Body).Decode(&response); err == nil {
if _, exists := response["retry_after"]; !exists {
t.Error("Expected retry_after in response")
}
@@ -81,17 +81,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
t.Run("Health_RateLimit_Enforced", func(t *testing.T) {
@@ -100,17 +100,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 3; i++ {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/health", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/health", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
t.Run("Metrics_RateLimit_Enforced", func(t *testing.T) {
@@ -119,17 +119,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ {
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
t.Run("RateLimit_Different_Endpoints_Independent", func(t *testing.T) {
@@ -139,17 +139,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
})
t.Run("RateLimit_With_Authentication", func(t *testing.T) {
@@ -166,20 +166,20 @@ func TestIntegration_RateLimiting(t *testing.T) {
user := createAuthenticatedUser(t, authService, suite.UserRepo, uniqueTestUsername(t, "ratelimit_auth"), uniqueTestEmail(t, "ratelimit_auth"))
for i := 0; i < 3; i++ {
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := httptest.NewRequest("GET", "/api/auth/me", nil)
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
}
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/auth/me", nil)
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests)
assertErrorResponse(t, recorder, http.StatusTooManyRequests)
})
}

View File

@@ -17,35 +17,29 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
router := ctx.Router
t.Run("SecurityHeaders_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/health")
router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
assertHeader(t, rec, "X-Content-Type-Options", "")
assertHeader(t, rec, "X-Frame-Options", "")
assertHeader(t, rec, "X-XSS-Protection", "")
assertHeader(t, request, "X-Content-Type-Options")
assertHeader(t, request, "X-Frame-Options")
assertHeader(t, request, "X-XSS-Protection")
})
t.Run("CORS_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("OPTIONS", "/api/posts", nil)
req.Header.Set("Origin", "http://localhost:3000")
rec := httptest.NewRecorder()
request := httptest.NewRequest("OPTIONS", "/api/posts", nil)
request.Header.Set("Origin", "http://localhost:3000")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
assertHeader(t, rec, "Access-Control-Allow-Origin", "")
assertHeader(t, recorder, "Access-Control-Allow-Origin")
})
t.Run("Logging_Middleware_Executes", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/health")
router.ServeHTTP(rec, req)
if rec.Code == 0 {
if request.Code == 0 {
t.Error("Expected logging middleware to execute")
}
})
@@ -53,27 +47,24 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
t.Run("RequestSizeLimit_Enforced", func(t *testing.T) {
user := createUserWithCleanup(t, ctx, "size_limit_user", "size_limit@example.com")
largeBody := strings.Repeat("a", 10*1024*1024)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(largeBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(largeBody))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusRequestEntityTooLarge && rec.Code != http.StatusBadRequest {
t.Errorf("Expected status 413 or 400 for oversized request, got %d. Body: %s", rec.Code, rec.Body.String())
if recorder.Code != http.StatusRequestEntityTooLarge && recorder.Code != http.StatusBadRequest {
t.Errorf("Expected status 413 or 400 for oversized request, got %d. Body: %s", recorder.Code, recorder.Body.String())
}
})
t.Run("DBMonitoring_Active", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
request := makeGetRequest(t, router, "/health")
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil {
if err := json.NewDecoder(request.Body).Decode(&response); err == nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database_stats"]; !exists {
t.Error("Expected database_stats in health response")
@@ -83,12 +74,9 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
})
t.Run("Metrics_Middleware_Executes", func(t *testing.T) {
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/metrics")
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database"]; !exists {
@@ -99,34 +87,25 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
})
t.Run("StaticFiles_Served", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/robots.txt")
router.ServeHTTP(rec, req)
assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
if !strings.Contains(rec.Body.String(), "User-agent") {
if !strings.Contains(request.Body.String(), "User-agent") {
t.Error("Expected robots.txt content")
}
})
t.Run("API_Routes_Accessible", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/api/posts")
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("Health_Endpoint_Accessible", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/health")
router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, request, http.StatusOK)
if response != nil {
if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success=true in health response")
@@ -135,40 +114,33 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
})
t.Run("Middleware_Order_Correct", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := makeGetRequest(t, router, "/api/posts")
router.ServeHTTP(rec, req)
assertHeader(t, request, "X-Content-Type-Options")
assertHeader(t, rec, "X-Content-Type-Options", "")
if rec.Code == 0 {
if request.Code == 0 {
t.Error("Response should have status code")
}
})
t.Run("Compression_Middleware_Active", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip")
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
request.Header.Set("Accept-Encoding", "gzip")
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Header().Get("Content-Encoding") == "" {
if recorder.Header().Get("Content-Encoding") == "" {
t.Log("Compression may not be applied to small responses")
}
})
t.Run("Cache_Middleware_Active", func(t *testing.T) {
req1 := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := makeGetRequest(t, router, "/api/posts")
req2 := httptest.NewRequest("GET", "/api/posts", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := makeGetRequest(t, router, "/api/posts")
if rec1.Code != rec2.Code {
if firstRequest.Code != secondRequest.Code {
t.Error("Cached responses should have same status")
}
})
@@ -177,35 +149,23 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "auth_middleware_user", "auth_middleware@example.com")
req := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := makeAuthenticatedGetRequest(t, router, "/api/auth/me", user, nil)
router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, request, http.StatusOK)
})
t.Run("RateLimit_Middleware_Integration", func(t *testing.T) {
rateLimitCtx := setupTestContext(t)
rateLimitRouter := rateLimitCtx.Router
for i := 0; i < 3; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
rateLimitRouter.ServeHTTP(rec, req)
for range 3 {
request := makePostRequestWithJSON(t, rateLimitRouter, "/api/auth/login", map[string]any{"username": "test", "password": "test"})
_ = request
}
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
request := makePostRequestWithJSON(t, rateLimitRouter, "/api/auth/login", map[string]any{"username": "test", "password": "test"})
rateLimitRouter.ServeHTTP(rec, req)
if rec.Code == http.StatusTooManyRequests {
if request.Code == http.StatusTooManyRequests {
t.Log("Rate limiting is working")
}
})

View File

@@ -21,60 +21,60 @@ func TestIntegration_SessionManagement(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_pass_user", "session_pass@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user.User.ID)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
firstRequest.Header.Set("Authorization", "Bearer "+user.Token)
firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, user.User.ID)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
assertStatus(t, rec1, http.StatusOK)
assertStatus(t, firstRecorder, http.StatusOK)
reqBody := map[string]string{
requestBody := map[string]string{
"current_password": "SecurePass123!",
"new_password": "NewSecurePass123!",
}
body, _ := json.Marshal(reqBody)
req2 := httptest.NewRequest("PUT", "/api/auth/password", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
body, _ := json.Marshal(requestBody)
secondRequest := httptest.NewRequest("PUT", "/api/auth/password", bytes.NewBuffer(body))
secondRequest.Header.Set("Content-Type", "application/json")
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec2, http.StatusOK)
assertStatus(t, secondRecorder, http.StatusOK)
req3 := httptest.NewRequest("GET", "/api/auth/me", nil)
req3.Header.Set("Authorization", "Bearer "+user.Token)
req3 = testutils.WithUserContext(req3, middleware.UserIDKey, user.User.ID)
rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3)
thirdRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
thirdRequest.Header.Set("Authorization", "Bearer "+user.Token)
thirdRequest = testutils.WithUserContext(thirdRequest, middleware.UserIDKey, user.User.ID)
thirdRecorder := httptest.NewRecorder()
router.ServeHTTP(thirdRecorder, thirdRequest)
assertErrorResponse(t, rec3, http.StatusUnauthorized)
assertErrorResponse(t, thirdRecorder, http.StatusUnauthorized)
})
t.Run("Session_Invalidation_On_Account_Lock", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_lock_user", "session_lock@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user.User.ID)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
firstRequest.Header.Set("Authorization", "Bearer "+user.Token)
firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, user.User.ID)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
assertStatus(t, rec1, http.StatusOK)
assertStatus(t, firstRecorder, http.StatusOK)
if err := ctx.Suite.UserRepo.Lock(user.User.ID); err != nil {
t.Fatalf("Failed to lock user: %v", err)
}
req2 := httptest.NewRequest("GET", "/api/auth/me", nil)
req2.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertErrorResponse(t, rec2, http.StatusUnauthorized)
assertErrorResponse(t, secondRecorder, http.StatusUnauthorized)
})
t.Run("Refresh_Token_Revocation", func(t *testing.T) {
@@ -90,48 +90,48 @@ func TestIntegration_SessionManagement(t *testing.T) {
t.Fatal("Expected refresh token")
}
reqBody := map[string]string{
requestBody := map[string]string{
"refresh_token": loginResult.RefreshToken,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK)
assertStatus(t, recorder, http.StatusOK)
if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil {
t.Fatalf("Failed to revoke token: %v", err)
}
req2 := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json")
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
secondRequest.Header.Set("Content-Type", "application/json")
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertErrorResponse(t, rec2, http.StatusUnauthorized)
assertErrorResponse(t, secondRecorder, http.StatusUnauthorized)
})
t.Run("Multiple_Sessions_Independent", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user1", "multi_session1@example.com")
user2 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user2", "multi_session2@example.com")
firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user1", "multi_session1@example.com")
secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user2", "multi_session2@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user1.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user1.User.ID)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
firstRequest.Header.Set("Authorization", "Bearer "+firstUser.Token)
firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, firstUser.User.ID)
firstRecorder := httptest.NewRecorder()
router.ServeHTTP(firstRecorder, firstRequest)
req2 := httptest.NewRequest("GET", "/api/auth/me", nil)
req2.Header.Set("Authorization", "Bearer "+user2.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user2.User.ID)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
secondRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
secondRequest.Header.Set("Authorization", "Bearer "+secondUser.Token)
secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, secondUser.User.ID)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec1, http.StatusOK)
assertStatus(t, rec2, http.StatusOK)
assertStatus(t, firstRecorder, http.StatusOK)
assertStatus(t, secondRecorder, http.StatusOK)
})
}
@@ -144,17 +144,17 @@ func TestIntegration_AccountDeletion(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_flow_user", "del_flow@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Test Post", "https://example.com")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
requestBody := map[string]string{}
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, recorder, http.StatusOK)
if response == nil {
return
}
@@ -171,13 +171,13 @@ func TestIntegration_AccountDeletion(t *testing.T) {
"token": deletionToken,
}
confirmBodyBytes, _ := json.Marshal(confirmBody)
confirmReq := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmReq.Header.Set("Content-Type", "application/json")
confirmRec := httptest.NewRecorder()
confirmRequest := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmRequest.Header.Set("Content-Type", "application/json")
confirmRecorder := httptest.NewRecorder()
router.ServeHTTP(confirmRec, confirmReq)
router.ServeHTTP(confirmRecorder, confirmRequest)
confirmResponse := assertJSONResponse(t, confirmRec, http.StatusOK)
confirmResponse := assertJSONResponse(t, confirmRecorder, http.StatusOK)
if confirmResponse == nil {
return
}
@@ -209,17 +209,17 @@ func TestIntegration_AccountDeletion(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_posts_user", "del_posts@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Deletion Post", "https://example.com/deletion")
reqBody := map[string]string{}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
requestBody := map[string]string{}
body, _ := json.Marshal(requestBody)
request := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, recorder, http.StatusOK)
if response == nil {
return
}
@@ -237,13 +237,13 @@ func TestIntegration_AccountDeletion(t *testing.T) {
"delete_posts": true,
}
confirmBodyBytes, _ := json.Marshal(confirmBody)
confirmReq := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmReq.Header.Set("Content-Type", "application/json")
confirmRec := httptest.NewRecorder()
confirmRequest := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmRequest.Header.Set("Content-Type", "application/json")
confirmRecorder := httptest.NewRecorder()
router.ServeHTTP(confirmRec, confirmReq)
router.ServeHTTP(confirmRecorder, confirmRequest)
confirmResponse := assertJSONResponse(t, confirmRec, http.StatusOK)
confirmResponse := assertJSONResponse(t, confirmRecorder, http.StatusOK)
if confirmResponse == nil {
return
}
@@ -275,12 +275,12 @@ func TestIntegration_MetricsCollection(t *testing.T) {
router := ctx.Router
t.Run("Metrics_Endpoint_Returns_Data", func(t *testing.T) {
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK)
response := assertJSONResponse(t, recorder, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database"]; !exists {
@@ -294,13 +294,13 @@ func TestIntegration_MetricsCollection(t *testing.T) {
ctx.Suite.EmailSender.Reset()
createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "metrics_user", "metrics@example.com")
req := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil {
if err := json.NewDecoder(recorder.Body).Decode(&response); err == nil {
if data, ok := response["data"].(map[string]any); ok {
if dbData, exists := data["database"].(map[string]any); exists {
if _, hasQueries := dbData["total_queries"]; !hasQueries {
@@ -323,7 +323,7 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, 10)
for i := 0; i < 10; i++ {
for idx := range 10 {
wg.Add(1)
go func(index int) {
defer wg.Done()
@@ -334,18 +334,18 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
"content": "Concurrent test content",
}
body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusCreated {
errors <- fmt.Errorf("Post %d failed with status %d", index, rec.Code)
if recorder.Code != http.StatusCreated {
errors <- fmt.Errorf("Post %d failed with status %d", index, recorder.Code)
}
}(i)
}(idx)
}
wg.Wait()
@@ -370,28 +370,26 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, 5)
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range 5 {
wg.Go(func() {
voteBody := map[string]string{
"type": "up",
}
body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
request := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK {
errors <- fmt.Errorf("Vote failed with status %d", rec.Code)
if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("Vote failed with status %d", recorder.Code)
}
}()
})
}
wg.Wait()
@@ -411,20 +409,18 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, 20)
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range 20 {
wg.Go(func() {
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req)
router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK {
errors <- fmt.Errorf("Read failed with status %d", rec.Code)
if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("Read failed with status %d", recorder.Code)
}
}()
})
}
wg.Wait()

View File

@@ -38,6 +38,10 @@ func CompressionMiddlewareWithConfig(config *CompressionConfig) func(http.Handle
next.ServeHTTP(bufferedWriter, r)
if bufferedWriter.isRedirect {
return
}
if buf.Len() < config.MinSize {
bufferedWriter.flush()
w.Write(buf.Bytes())
@@ -73,9 +77,13 @@ type bufferedResponseWriter struct {
buffer *bytes.Buffer
statusCode int
headerWritten bool
isRedirect bool
}
func (brw *bufferedResponseWriter) Write(b []byte) (int, error) {
if brw.isRedirect {
return brw.ResponseWriter.Write(b)
}
if !brw.headerWritten {
brw.statusCode = http.StatusOK
}
@@ -87,6 +95,11 @@ func (brw *bufferedResponseWriter) WriteHeader(code int) {
return
}
brw.statusCode = code
if isRedirect(code) {
brw.isRedirect = true
brw.ResponseWriter.WriteHeader(code)
brw.headerWritten = true
}
}
func (brw *bufferedResponseWriter) Header() http.Header {
@@ -100,6 +113,10 @@ func (brw *bufferedResponseWriter) flush() {
}
}
func isRedirect(statusCode int) bool {
return statusCode >= 300 && statusCode < 400
}
func shouldCompress(r *http.Request, config *CompressionConfig) bool {
return r.Header.Get("Content-Encoding") == ""
}

View File

@@ -27,7 +27,14 @@ func ValidationMiddleware() func(http.Handler) http.Handler {
dto := reflect.New(dtoType).Interface()
if err := json.NewDecoder(r.Body).Decode(dto); err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
response := map[string]any{
"success": false,
"error": "Invalid JSON",
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(response)
return
}
@@ -77,3 +84,7 @@ func GetDTOTypeFromContext(ctx context.Context) reflect.Type {
func GetValidatedDTOFromContext(ctx context.Context) any {
return ctx.Value(validatedDTOKey)
}
func SetValidatedDTOInContext(ctx context.Context, dto any) context.Context {
return context.WithValue(ctx, validatedDTOKey, dto)
}

View File

@@ -28,6 +28,7 @@ type UserRepository interface {
Unlock(id uint) error
GetPosts(userID uint, limit, offset int) ([]database.Post, error)
GetDeletedUsers() ([]database.User, error)
GetByUsernamePrefix(prefix string) (*database.User, error)
HardDeleteAll() (int64, error)
Count() (int64, error)
WithTx(tx *gorm.DB) UserRepository
@@ -240,6 +241,17 @@ func (r *userRepository) GetDeletedUsers() ([]database.User, error) {
return users, err
}
func (r *userRepository) GetByUsernamePrefix(prefix string) (*database.User, error) {
var user database.User
err := r.db.
Where("username LIKE ? AND email LIKE ?", prefix+"%", prefix+"%@goyco.local").
First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
func (r *userRepository) HardDeleteAll() (int64, error) {
var totalDeleted int64
err := r.db.Transaction(func(tx *gorm.DB) error {

View File

@@ -20,6 +20,7 @@ type VoteRepository interface {
Count() (int64, error)
CountByPostID(postID uint) (int64, error)
CountByUserID(userID uint) (int64, error)
GetVoteCountsByPostID(postID uint) (upVotes int, downVotes int, err error)
WithTx(tx *gorm.DB) VoteRepository
}
@@ -144,3 +145,20 @@ func (r *voteRepository) Count() (int64, error) {
err := r.db.Model(&database.Vote{}).Count(&count).Error
return count, err
}
func (r *voteRepository) GetVoteCountsByPostID(postID uint) (int, int, error) {
var result struct {
UpVotes int64
DownVotes int64
}
err := r.db.Model(&database.Vote{}).
Select("COUNT(CASE WHEN type = ? THEN 1 END) as up_votes, COUNT(CASE WHEN type = ? THEN 1 END) as down_votes", database.VoteUp, database.VoteDown).
Where("post_id = ?", postID).
Scan(&result).Error
if err != nil {
return 0, 0, err
}
return int(result.UpVotes), int(result.DownVotes), nil
}

View File

@@ -1,8 +1,10 @@
package server
import (
"mime"
"net/http"
"path/filepath"
"strings"
"time"
"goyco/internal/config"
@@ -73,6 +75,7 @@ func NewRouter(cfg RouterConfig) http.Handler {
},
CSRFMiddleware: middleware.CSRFMiddleware(),
AuthMiddleware: middleware.NewAuth(cfg.AuthService),
ValidationMiddleware: middleware.ValidationMiddleware(),
}
if cfg.PageHandler != nil {
@@ -123,7 +126,33 @@ func NewRouter(cfg RouterConfig) http.Handler {
staticDir = "./internal/static/"
}
router.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir))))
staticFileServer := http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir)))
router.Handle("/static/*", staticFileHandler(staticFileServer))
return router
}
func staticFileHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
ext := filepath.Ext(path)
if ext == ".css" {
w.Header().Set("Content-Type", "text/css; charset=utf-8")
} else if ext == ".js" {
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
} else if ext == ".json" {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
} else if ext == ".ico" {
w.Header().Set("Content-Type", "image/x-icon")
} else if strings.HasPrefix(mime.TypeByExtension(ext), "image/") {
w.Header().Set("Content-Type", mime.TypeByExtension(ext))
} else if strings.HasPrefix(mime.TypeByExtension(ext), "font/") {
w.Header().Set("Content-Type", mime.TypeByExtension(ext))
} else if mimeType := mime.TypeByExtension(ext); mimeType != "" {
w.Header().Set("Content-Type", mimeType)
}
next.ServeHTTP(w, r)
})
}

View File

@@ -3,6 +3,7 @@ package server
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"goyco/internal/config"
@@ -105,9 +106,9 @@ func defaultRateLimitConfig() config.RateLimitConfig {
return testutils.AppTestConfig.RateLimit
}
func TestAPIRootRouting(t *testing.T) {
func createDefaultRouterConfig() RouterConfig {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
return RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
@@ -115,7 +116,15 @@ func TestAPIRootRouting(t *testing.T) {
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
}
}
func createTestRouter(cfg RouterConfig) http.Handler {
return NewRouter(cfg)
}
func TestAPIRootRouting(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
@@ -141,23 +150,23 @@ func TestAPIRootRouting(t *testing.T) {
}
func TestProtectedRoutesRequireAuth(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
protectedRoutes := []struct {
method string
path string
}{
{http.MethodGet, "/api/auth/me"},
{http.MethodPost, "/api/auth/logout"},
{http.MethodPost, "/api/auth/revoke"},
{http.MethodPost, "/api/auth/revoke-all"},
{http.MethodPut, "/api/auth/email"},
{http.MethodPut, "/api/auth/username"},
{http.MethodPut, "/api/auth/password"},
{http.MethodDelete, "/api/auth/account"},
{http.MethodPost, "/api/posts"},
{http.MethodPut, "/api/posts/1"},
{http.MethodDelete, "/api/posts/1"},
{http.MethodPost, "/api/posts/1/vote"},
{http.MethodDelete, "/api/posts/1/vote"},
{http.MethodGet, "/api/posts/1/vote"},
@@ -183,17 +192,9 @@ func TestProtectedRoutesRequireAuth(t *testing.T) {
}
func TestRouterWithDebugMode(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
Debug: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.Debug = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -206,16 +207,9 @@ func TestRouterWithDebugMode(t *testing.T) {
}
func TestRouterWithCacheDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
DisableCache: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
})
cfg := createDefaultRouterConfig()
cfg.DisableCache = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -228,17 +222,9 @@ func TestRouterWithCacheDisabled(t *testing.T) {
}
func TestRouterWithCompressionDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
DisableCompression: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.DisableCompression = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -251,19 +237,9 @@ func TestRouterWithCompressionDisabled(t *testing.T) {
}
func TestRouterWithCustomDBMonitor(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
customDBMonitor := middleware.NewInMemoryDBMonitor()
router := NewRouter(RouterConfig{
DBMonitor: customDBMonitor,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.DBMonitor = middleware.NewInMemoryDBMonitor()
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -296,18 +272,9 @@ func TestRouterWithPageHandler(t *testing.T) {
}
func TestRouterWithStaticDir(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
StaticDir: "/custom/static/path",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.StaticDir = "/custom/static/path"
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -320,18 +287,9 @@ func TestRouterWithStaticDir(t *testing.T) {
}
func TestRouterWithEmptyStaticDir(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
StaticDir: "",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.StaticDir = ""
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -344,20 +302,11 @@ func TestRouterWithEmptyStaticDir(t *testing.T) {
}
func TestRouterWithAllFeaturesDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
Debug: true,
DisableCache: true,
DisableCompression: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.Debug = true
cfg.DisableCache = true
cfg.DisableCompression = true
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -370,15 +319,9 @@ func TestRouterWithAllFeaturesDisabled(t *testing.T) {
}
func TestRouterWithoutAPIHandler(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, _, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.APIHandler = nil
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder()
@@ -391,17 +334,7 @@ func TestRouterWithoutAPIHandler(t *testing.T) {
}
func TestRouterWithoutPageHandler(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
request := httptest.NewRequest(http.MethodGet, "/", nil)
recorder := httptest.NewRecorder()
@@ -414,17 +347,7 @@ func TestRouterWithoutPageHandler(t *testing.T) {
}
func TestSwaggerRoute(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
recorder := httptest.NewRecorder()
@@ -437,18 +360,9 @@ func TestSwaggerRoute(t *testing.T) {
}
func TestStaticFileRoute(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
StaticDir: "../../internal/static/",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
cfg := createDefaultRouterConfig()
cfg.StaticDir = "../../internal/static/"
router := createTestRouter(cfg)
request := httptest.NewRequest(http.MethodGet, "/static/css/main.css", nil)
recorder := httptest.NewRecorder()
@@ -461,17 +375,7 @@ func TestStaticFileRoute(t *testing.T) {
}
func TestRouterConfiguration(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
router := createTestRouter(createDefaultRouterConfig())
if router == nil {
t.Error("Router should not be nil")
@@ -487,29 +391,484 @@ func TestRouterConfiguration(t *testing.T) {
}
}
func TestRouterMiddlewareIntegration(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
func TestAllRoutesExist(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
if router == nil {
t.Error("Router should not be nil")
publicRoutes := []struct {
method string
path string
description string
}{
{http.MethodGet, "/api", "API info"},
{http.MethodGet, "/health", "Health check"},
{http.MethodGet, "/metrics", "Metrics"},
{http.MethodGet, "/robots.txt", "Robots.txt"},
{http.MethodGet, "/api/posts", "Get posts"},
{http.MethodGet, "/api/posts/search", "Search posts"},
{http.MethodGet, "/api/posts/title", "Fetch title from URL"},
{http.MethodGet, "/api/posts/1", "Get post by ID"},
{http.MethodPost, "/api/auth/register", "Register"},
{http.MethodPost, "/api/auth/login", "Login"},
{http.MethodPost, "/api/auth/refresh", "Refresh token"},
{http.MethodGet, "/api/auth/confirm", "Confirm email"},
{http.MethodPost, "/api/auth/resend-verification", "Resend verification"},
{http.MethodPost, "/api/auth/forgot-password", "Forgot password"},
{http.MethodPost, "/api/auth/reset-password", "Reset password"},
{http.MethodPost, "/api/auth/account/confirm", "Confirm account deletion"},
}
request := httptest.NewRequest(http.MethodGet, "/api", nil)
protectedRoutes := []struct {
method string
path string
description string
}{
{http.MethodGet, "/api/auth/me", "Get current user"},
{http.MethodPost, "/api/auth/logout", "Logout"},
{http.MethodPost, "/api/auth/revoke", "Revoke token"},
{http.MethodPost, "/api/auth/revoke-all", "Revoke all tokens"},
{http.MethodPut, "/api/auth/email", "Update email"},
{http.MethodPut, "/api/auth/username", "Update username"},
{http.MethodPut, "/api/auth/password", "Update password"},
{http.MethodDelete, "/api/auth/account", "Delete account"},
{http.MethodPost, "/api/posts", "Create post"},
{http.MethodPut, "/api/posts/1", "Update post"},
{http.MethodDelete, "/api/posts/1", "Delete post"},
{http.MethodPost, "/api/posts/1/vote", "Cast vote"},
{http.MethodDelete, "/api/posts/1/vote", "Remove vote"},
{http.MethodGet, "/api/posts/1/vote", "Get user vote"},
{http.MethodGet, "/api/posts/1/votes", "Get post votes"},
{http.MethodGet, "/api/users", "Get users"},
{http.MethodPost, "/api/users", "Create user"},
{http.MethodGet, "/api/users/1", "Get user by ID"},
{http.MethodGet, "/api/users/1/posts", "Get user posts"},
}
for _, route := range publicRoutes {
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
invalidMethod := http.MethodPatch
switch route.method {
case http.MethodGet:
invalidMethod = http.MethodDelete
case http.MethodPost:
invalidMethod = http.MethodGet
}
request := httptest.NewRequest(invalidMethod, route.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == 0 {
t.Error("Router should return a status code")
routeExists := recorder.Code == http.StatusMethodNotAllowed
if !routeExists {
request = httptest.NewRequest(route.method, route.path, nil)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == http.StatusNotFound && route.path != "/api/posts/1" && route.path != "/robots.txt" {
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
}
}
})
}
for _, route := range protectedRoutes {
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
request := httptest.NewRequest(route.method, route.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
}
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Protected route %s %s should return 401 without auth, got %d", route.method, route.path, recorder.Code)
}
})
}
}
func TestRouteParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
pathPattern string
testIDs []string
isProtected bool
}{
{
name: "Get post by ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: false,
},
{
name: "Update post by ID",
method: http.MethodPut,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Delete post by ID",
method: http.MethodDelete,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Get user by ID",
method: http.MethodGet,
pathPattern: "/api/users/{id}",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Get user posts by user ID",
method: http.MethodGet,
pathPattern: "/api/users/{id}/posts",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Cast vote for post ID",
method: http.MethodPost,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Remove vote for post ID",
method: http.MethodDelete,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Get user vote for post ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Get post votes by post ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}/votes",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.testIDs {
path := replaceID(tc.pathPattern, id)
t.Run("ID_"+id, func(t *testing.T) {
request := httptest.NewRequest(http.MethodPatch, path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
routeExists := recorder.Code == http.StatusMethodNotAllowed
request = httptest.NewRequest(tc.method, path, nil)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if !routeExists {
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist with ID %s, got 404", tc.method, path, id)
return
}
}
if tc.isProtected {
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Protected route %s %s should return 401 without auth, got %d", tc.method, path, recorder.Code)
}
}
})
}
})
}
}
func replaceID(pattern, id string) string {
return strings.Replace(pattern, "{id}", id, 1)
}
func TestInvalidRouteParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
expectedMin int
expectedMax int
isProtected bool
allow401 bool
}{
{
name: "Non-numeric post ID",
method: http.MethodGet,
path: "/api/posts/abc",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Negative post ID",
method: http.MethodGet,
path: "/api/posts/-1",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Zero post ID",
method: http.MethodGet,
path: "/api/posts/0",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusNotFound,
isProtected: false,
},
{
name: "Post ID with special characters",
method: http.MethodGet,
path: "/api/posts/123@456",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Post ID with encoded spaces",
method: http.MethodGet,
path: "/api/posts/12%2034",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Non-numeric user ID",
method: http.MethodGet,
path: "/api/users/xyz",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Negative user ID",
method: http.MethodGet,
path: "/api/users/-5",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Non-numeric post ID in vote route",
method: http.MethodGet,
path: "/api/posts/invalid/vote",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Very large post ID",
method: http.MethodGet,
path: "/api/posts/999999999999",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusNotFound,
isProtected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest(tc.method, tc.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if tc.isProtected && tc.allow401 {
if recorder.Code != http.StatusUnauthorized && (recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax) {
t.Errorf("Protected route %s %s with invalid parameter should return 401 or status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
}
} else {
if recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax {
t.Errorf("Route %s %s should return status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
}
if recorder.Code != http.StatusNotFound && recorder.Code < 400 {
t.Errorf("Route %s %s with invalid parameter should return error status (4xx), got %d", tc.method, tc.path, recorder.Code)
}
}
})
}
}
func TestQueryParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
queryParams string
expectRoute bool
}{
{
name: "Get posts with limit and offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=10&offset=5",
expectRoute: true,
},
{
name: "Get posts with only limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=20",
expectRoute: true,
},
{
name: "Get posts with only offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "offset=10",
expectRoute: true,
},
{
name: "Search posts with query parameter",
method: http.MethodGet,
path: "/api/posts/search",
queryParams: "q=test",
expectRoute: true,
},
{
name: "Search posts with query, limit, and offset",
method: http.MethodGet,
path: "/api/posts/search",
queryParams: "q=test&limit=15&offset=3",
expectRoute: true,
},
{
name: "Fetch title with URL parameter",
method: http.MethodGet,
path: "/api/posts/title",
queryParams: "url=https://example.com",
expectRoute: true,
},
{
name: "Confirm email with token parameter",
method: http.MethodGet,
path: "/api/auth/confirm",
queryParams: "token=abc123",
expectRoute: true,
},
{
name: "Get posts with invalid limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=abc",
expectRoute: true,
},
{
name: "Get posts with negative limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=-5",
expectRoute: true,
},
{
name: "Get posts with negative offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "offset=-10",
expectRoute: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fullPath := tc.path
if tc.queryParams != "" {
fullPath += "?" + tc.queryParams
}
request := httptest.NewRequest(tc.method, fullPath, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if tc.expectRoute {
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist with query parameters, got 404", tc.method, fullPath)
}
}
})
}
}
func TestRouteConflicts(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
description string
}{
{
name: "posts/search should not match posts/{id}",
method: http.MethodGet,
path: "/api/posts/search",
description: "search route should be matched, not treated as ID",
},
{
name: "posts/title should not match posts/{id}",
method: http.MethodGet,
path: "/api/posts/title",
description: "title route should be matched, not treated as ID",
},
{
name: "posts/{id} should work with numeric ID",
method: http.MethodGet,
path: "/api/posts/123",
description: "numeric ID should match {id} route",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest(tc.method, tc.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
switch tc.path {
case "/api/posts/search":
if recorder.Code == http.StatusNotFound {
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
}
case "/api/posts/title":
if recorder.Code == http.StatusNotFound {
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
}
case "/api/posts/123":
if recorder.Code == http.StatusNotFound {
return
}
if recorder.Code < 400 {
t.Errorf("%s: Route %s %s should return 4xx or 5xx, got %d", tc.description, tc.method, tc.path, recorder.Code)
}
}
})
}
}

View File

@@ -1,12 +1,93 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/mail"
"strings"
"goyco/internal/config"
"goyco/internal/database"
)
const (
defaultTokenExpirationHours = 24
verificationTokenBytes = 32
deletionTokenExpirationHours = 24
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid or expired token")
ErrUsernameTaken = errors.New("username already exists")
ErrEmailTaken = errors.New("email already exists")
ErrInvalidEmail = errors.New("invalid email address")
ErrPasswordTooShort = errors.New("password too short")
ErrEmailNotVerified = errors.New("email not verified")
ErrAccountLocked = errors.New("account is locked")
ErrInvalidVerificationToken = errors.New("invalid verification token")
ErrEmailSenderUnavailable = errors.New("email sender not configured")
ErrDeletionEmailFailed = errors.New("account deletion email failed")
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
ErrUserNotFound = errors.New("user not found")
ErrDeletionRequestNotFound = errors.New("deletion request not found")
)
type AuthResult struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
User *database.User `json:"user"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
VerificationSent bool `json:"verification_sent"`
}
func normalizeEmail(email string) (string, error) {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return "", fmt.Errorf("email is required")
}
parsed, err := mail.ParseAddress(trimmed)
if err != nil {
return "", ErrInvalidEmail
}
return strings.ToLower(parsed.Address), nil
}
func generateVerificationToken() (string, string, error) {
buf := make([]byte, verificationTokenBytes)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate verification token: %w", err)
}
token := hex.EncodeToString(buf)
hashed := HashVerificationToken(token)
return token, hashed, nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func sanitizeUser(user *database.User) *database.User {
if user == nil {
return nil
}
copy := *user
copy.Password = ""
copy.EmailVerificationToken = ""
return &copy
}
type AuthFacade struct {
registrationService *RegistrationService
passwordResetService *PasswordResetService

View File

@@ -1,35 +0,0 @@
package services
import (
"errors"
"goyco/internal/database"
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid or expired token")
ErrUsernameTaken = errors.New("username already exists")
ErrEmailTaken = errors.New("email already exists")
ErrInvalidEmail = errors.New("invalid email address")
ErrPasswordTooShort = errors.New("password too short")
ErrEmailNotVerified = errors.New("email not verified")
ErrAccountLocked = errors.New("account is locked")
ErrInvalidVerificationToken = errors.New("invalid verification token")
ErrEmailSenderUnavailable = errors.New("email sender not configured")
ErrDeletionEmailFailed = errors.New("account deletion email failed")
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
ErrUserNotFound = errors.New("user not found")
ErrDeletionRequestNotFound = errors.New("deletion request not found")
)
type AuthResult struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
User *database.User `json:"user"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
VerificationSent bool `json:"verification_sent"`
}

View File

@@ -1,59 +0,0 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/mail"
"strings"
"goyco/internal/database"
)
const (
defaultTokenExpirationHours = 24
verificationTokenBytes = 32
deletionTokenExpirationHours = 24
)
func normalizeEmail(email string) (string, error) {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return "", fmt.Errorf("email is required")
}
parsed, err := mail.ParseAddress(trimmed)
if err != nil {
return "", ErrInvalidEmail
}
return strings.ToLower(parsed.Address), nil
}
func generateVerificationToken() (string, string, error) {
buf := make([]byte, verificationTokenBytes)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate verification token: %w", err)
}
token := hex.EncodeToString(buf)
hashed := HashVerificationToken(token)
return token, hashed, nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func sanitizeUser(user *database.User) *database.User {
if user == nil {
return nil
}
copy := *user
copy.Password = ""
copy.EmailVerificationToken = ""
return &copy
}

View File

@@ -7,9 +7,10 @@ import (
"sync"
"testing"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/repositories"
"gorm.io/gorm"
)
type mockVoteRepo struct {
@@ -240,6 +241,25 @@ func (m *mockVoteRepo) CountByUserID(userID uint) (int64, error) {
return count, nil
}
func (m *mockVoteRepo) GetVoteCountsByPostID(postID uint) (int, int, error) {
m.mu.RLock()
defer m.mu.RUnlock()
upVotes := 0
downVotes := 0
for _, vote := range m.votes {
if vote.PostID == postID {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
}
return upVotes, downVotes, nil
}
func (m *mockVoteRepo) WithTx(tx *gorm.DB) repositories.VoteRepository {
return m
}

View File

@@ -2,43 +2,70 @@ package templates
import (
"html/template"
"io/fs"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTemplateParsing(t *testing.T) {
templateDir := "./"
func templateFuncMap() template.FuncMap {
return template.FuncMap{
"formatTime": func(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format("02 Jan 2006 15:04")
},
"truncate": func(s string, length int) string {
if len(s) <= length {
return s
}
if length <= 3 {
return s[:length]
}
return s[:length-3] + "..."
},
"substr": func(s string, start, length int) string {
if start >= len(s) {
return ""
}
end := min(start+length, len(s))
return s[start:end]
},
"upper": strings.ToUpper,
}
}
var templateFiles []string
err := filepath.WalkDir(templateDir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if !d.IsDir() && filepath.Ext(path) == ".gohtml" {
templateFiles = append(templateFiles, path)
}
return nil
})
func TestTemplateParsing(t *testing.T) {
layoutPath := filepath.Join(".", "base.gohtml")
require.FileExists(t, layoutPath, "base layout is required for all templates")
partials, err := filepath.Glob(filepath.Join(".", "partials", "*.gohtml"))
require.NoError(t, err)
tmpl := template.New("test")
pages, err := filepath.Glob(filepath.Join(".", "*.gohtml"))
require.NoError(t, err)
require.NotEmpty(t, pages, "no page templates found")
tmpl = tmpl.Funcs(template.FuncMap{
"formatTime": func(any) string { return "2024-01-01" },
"eq": func(a, b any) bool { return a == b },
"ne": func(a, b any) bool { return a != b },
"len": func(s any) int { return 0 },
"range": func(s any) any { return s },
})
for _, page := range pages {
if filepath.Base(page) == "base.gohtml" {
continue
}
for _, file := range templateFiles {
t.Run(file, func(t *testing.T) {
_, err := tmpl.ParseFiles(file)
assert.NoError(t, err, "Template %s should parse without errors", file)
page := page
t.Run(filepath.Base(page), func(t *testing.T) {
t.Parallel()
files := append([]string{layoutPath}, partials...)
files = append(files, page)
tmpl, err := template.New(filepath.Base(page)).Funcs(templateFuncMap()).ParseFiles(files...)
require.NoError(t, err)
require.NotNil(t, tmpl.Lookup("layout"), "layout template should be available")
require.NotNil(t, tmpl.Lookup("content"), "content block should be defined by page templates")
})
}
}

View File

@@ -422,6 +422,24 @@ func (m *MockUserRepository) GetDeletedUsers() ([]database.User, error) {
return []database.User{}, nil
}
func (m *MockUserRepository) GetByUsernamePrefix(prefix string) (*database.User, error) {
m.mu.RLock()
defer m.mu.RUnlock()
for _, user := range m.users {
if len(user.Username) >= len(prefix) && user.Username[:len(prefix)] == prefix {
if len(user.Email) >= 13 && strings.HasSuffix(user.Email, "@goyco.local") {
emailPrefix := user.Email[:len(user.Email)-13]
if len(emailPrefix) >= len(prefix) && emailPrefix[:len(prefix)] == prefix {
return user, nil
}
}
}
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockUserRepository) HardDeleteAll() (int64, error) {
if m.HardDeleteAllFunc != nil {
return m.HardDeleteAllFunc()
@@ -965,6 +983,25 @@ func (m *MockVoteRepository) CountByUserID(userID uint) (int64, error) {
return count, nil
}
func (m *MockVoteRepository) GetVoteCountsByPostID(postID uint) (int, int, error) {
m.mu.RLock()
defer m.mu.RUnlock()
upVotes := 0
downVotes := 0
for _, vote := range m.votes {
if vote.PostID == postID {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
}
return upVotes, downVotes, nil
}
func (m *MockVoteRepository) Count() (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()

View File

@@ -153,6 +153,7 @@ type UserRepositoryStub struct {
UnlockFn func(uint) error
GetPostsFn func(uint, int, int) ([]database.Post, error)
GetDeletedUsersFn func() ([]database.User, error)
GetByUsernamePrefixFn func(string) (*database.User, error)
HardDeleteAllFn func() (int64, error)
CountFn func() (int64, error)
WithTxFn func(*gorm.DB) repositories.UserRepository
@@ -281,6 +282,13 @@ func (s *UserRepositoryStub) GetDeletedUsers() ([]database.User, error) {
return nil, nil
}
func (s *UserRepositoryStub) GetByUsernamePrefix(prefix string) (*database.User, error) {
if s != nil && s.GetByUsernamePrefixFn != nil {
return s.GetByUsernamePrefixFn(prefix)
}
return nil, gorm.ErrRecordNotFound
}
func (s *UserRepositoryStub) HardDeleteAll() (int64, error) {
if s != nil && s.HardDeleteAllFn != nil {
return s.HardDeleteAllFn()

View File

@@ -30,7 +30,7 @@ if [ ! -d "docs" ]; then
mkdir -p docs
fi
SWAGGER_DIRECTORIES="cmd/goyco,internal/handlers"
SWAGGER_DIRECTORIES="cmd/goyco,internal/handlers,internal/dto"
SWAGGER_MAIN_FILE="main.go"
SWAGGER_OUTPUT_DIR="docs"

View File

@@ -6,9 +6,9 @@ if [ "$EUID" -ne 0 ]; then
exit 1
fi
read -s "Do you want to install PostgreSQL 17? [y/N] " INSTALL_PG
read -s "Do you want to install PostgreSQL 18? [y/N] " INSTALL_PG
if [ "$INSTALL_PG" != "y" ]; then
echo "PostgreSQL 17 will not be installed"
echo "PostgreSQL 18 will not be installed"
exit 0
fi
@@ -16,7 +16,7 @@ read -s -p "Enter password for PostgreSQL user 'goyco': " GOYCO_PWD
echo
apt-get update
apt-get install -y postgresql-17
apt-get install -y postgresql-18
systemctl enable --now postgresql
@@ -43,6 +43,4 @@ END
GRANT ALL PRIVILEGES ON DATABASE goyco TO goyco;
EOF
echo "PostgreSQL 17 installed, database 'goyco' and user 'goyco' set up."
echo "PostgreSQL 18 installed, database 'goyco' and user 'goyco' set up."