Compare commits

..

57 Commits

Author SHA1 Message Date
65109a787c feat: use GetVersion() 2026-01-26 22:17:14 +01:00
75f1406edf feat: use a getter 2026-01-26 22:17:02 +01:00
11dc9b507f feat: bump version to 0.1.1 2026-01-19 21:07:39 +01:00
da616438e9 chore: update version in swagger 2026-01-19 21:07:30 +01:00
7486865343 lint: remove duplicate string literals in seed tests 2026-01-19 16:43:51 +01:00
fd0fd8954a fix: close captureOutput pipe before read 2026-01-19 16:37:22 +01:00
628db14f59 fix: avoid Update deadlock by unlocking before display 2026-01-19 16:37:15 +01:00
7be196e4c3 test: move seed RNG to tests and add help/error cases 2026-01-19 16:37:01 +01:00
2f4bd45efb feat: make seed transactional and sequential with helpers 2026-01-19 16:36:51 +01:00
1b53c2b66b clean: get rid of parallel processor 2026-01-19 16:36:40 +01:00
509e68f538 docs: review roadmap 2026-01-16 11:23:27 +01:00
e6a44d830e fix: avoid repeated string concatenation 2026-01-14 17:05:20 +01:00
fe396b7537 feat: scope help printer to root command run 2026-01-14 13:00:03 +01:00
6eb04aa3c5 refactor: adapt test name 2026-01-14 12:59:14 +01:00
517d4482c9 test: fuzz urfave command path 2026-01-13 07:58:08 +01:00
b6e2bf942a tests: drive cli via urfave root command 2026-01-13 07:57:48 +01:00
9f1058ba81 tests: assert server fields and use urfave cli 2026-01-13 07:57:37 +01:00
2bdbb29ae6 refactor: remove legacy dispatch 2026-01-13 07:57:26 +01:00
9d243a0ed1 docs: mark cli migration as complete 2026-01-13 07:46:38 +01:00
9c74828b8d tests: fuzz urfave command parsing 2026-01-13 07:46:30 +01:00
9e78477eb5 tests: update cli help/json checks 2026-01-13 07:46:23 +01:00
a74980caa1 deps: add urfave/cli v3 checksums 2026-01-13 07:46:05 +01:00
816f08a20a deps: add urfave/cli v3 2026-01-13 07:45:53 +01:00
0cec152486 feat: migrate cli to urfave/cli v3 2026-01-13 07:45:18 +01:00
5413737491 test: match validation error casing with json tags 2026-01-12 22:49:40 +01:00
5f605e45c7 test: align title validation errors with json tags 2026-01-12 22:49:30 +01:00
e5779183ff test: cover json tag display and whitespace required case 2026-01-12 22:49:17 +01:00
4814b64c2c refactor: improve validation messages and string handling 2026-01-12 22:49:08 +01:00
45cad505d6 fix: break import cycle by inlining fuzz helpers 2026-01-12 22:40:12 +01:00
7f52347854 fix: enable foreign keys before AutoMigrate in fuzz DB 2026-01-12 22:37:54 +01:00
542913cbef fix: enable foreign key enforcement in fuzz DB 2026-01-12 22:36:46 +01:00
2f964b0c79 fix: prevent schema drift in fuzz tests with AutoMigrate 2026-01-12 22:35:56 +01:00
250ff79eeb test: update TestGetFuzzDB to expect new DB instances per call 2026-01-12 22:34:44 +01:00
4dfe260953 fix: remove global sync.Once to prevent DB state leakage in fuzz tests 2026-01-12 22:34:36 +01:00
49e6bb1e9d test: simplify pagination test loops 2026-01-12 12:26:26 +01:00
5b0c6018c0 test: cover pagination 2026-01-12 12:24:50 +01:00
3303d13f15 refactor: move TestApplyPagination to its own file 2026-01-12 12:24:42 +01:00
c1746eb346 docs: update readme 2026-01-10 23:03:07 +01:00
e2804ca07e refactor: use GetValidatedDTO for vote validation 2026-01-10 23:01:15 +01:00
6cdad79caa refactor: req -> request 2026-01-10 23:00:59 +01:00
6227b64746 refactor: use GetValidatedDTO for user create validation 2026-01-10 23:00:08 +01:00
506e233347 test: adjust post creation tests for DTO validation 2026-01-10 22:59:56 +01:00
8c06c916e1 refactor: use GetValidatedDTO for request validation 2026-01-10 22:59:43 +01:00
29fcaab25d refactor: centralize DTO decode/validation and skip duplicate validation 2026-01-10 22:59:29 +01:00
422ff2473e test: align auth handler expectations with validation errors 2026-01-10 22:59:21 +01:00
dbe4879457 refactor: route validation errors through GetValidatedDTO 2026-01-10 22:59:13 +01:00
5a530b7609 docs: update swagger 2026-01-10 22:58:58 +01:00
66b4b0e173 refactor: use AuthResponseDTO instead of services.AuthResult 2026-01-10 22:47:10 +01:00
e08e2b3189 feat: add AuthResponseDTO for login and refresh token responses 2026-01-10 22:46:56 +01:00
f39dcff67d refactor: use DTOs instead of maps and nil for responses 2026-01-10 22:44:39 +01:00
08d8d0ed22 feat: add MessageResponseDTO and EmptyResponseDTO 2026-01-10 22:44:07 +01:00
932c042aa2 feat: add TitleResponseDTO for title fetch responses 2026-01-10 22:43:57 +01:00
a1466e860d refactor: remove redundant validation, trust middleware and service layer 2026-01-10 22:41:54 +01:00
a1e63b868f refactor: remove redundant validation and unused import 2026-01-10 22:41:43 +01:00
b6f5293c0f refactor: remove redundant validation 2026-01-10 22:41:31 +01:00
6643466d76 refactor: use ToSanitizedUserListDTO and ToPostListDTO helpers 2026-01-10 22:39:04 +01:00
9dcf748474 refactor: use ToPostListDTO and ToSearchPostListDTO helpers 2026-01-10 22:38:54 +01:00
38 changed files with 1269 additions and 1758 deletions

View File

@@ -203,25 +203,44 @@ It'll be more readable and easier to parse.
- `POST /api/auth/register` - Register new user
- `POST /api/auth/login` - Login user
- `GET /api/auth/confirm` - Confirm email
- `POST /api/auth/logout` - Logout user
- `GET /api/auth/confirm` - Confirm email address
- `POST /api/auth/resend-verification` - Resend verification email
- `POST /api/auth/forgot-password` - Request password reset
- `POST /api/auth/reset-password` - Reset password
- `GET /api/auth/me` - Get current user profile (protected)
- `POST /api/auth/logout` - Logout user (protected)
- `POST /api/auth/refresh` - Refresh access token (rotates refresh token)
- `POST /api/auth/revoke` - Revoke a refresh token
- `POST /api/auth/revoke-all` - Revoke all refresh tokens for the current user
- `POST /api/auth/revoke` - Revoke a refresh token (protected)
- `POST /api/auth/revoke-all` - Revoke all refresh tokens for the current user (protected)
- `PUT /api/auth/email` - Update email address (protected)
- `PUT /api/auth/username` - Update username (protected)
- `PUT /api/auth/password` - Update password (protected)
- `DELETE /api/auth/account` - Request account deletion (protected)
- `POST /api/auth/account/confirm` - Confirm account deletion
#### Posts
- `GET /api/posts` - List posts
- `POST /api/posts` - Create post
- `GET /api/posts/search` - Search posts
- `GET /api/posts/title` - Fetch title from URL
- `GET /api/posts/{id}` - Get specific post
- `PUT /api/posts/{id}` - Update post
- `DELETE /api/posts/{id}` - Delete post
- `POST /api/posts` - Create post (protected)
- `PUT /api/posts/{id}` - Update post (protected)
- `DELETE /api/posts/{id}` - Delete post (protected)
#### Voting
- `POST /api/posts/{id}/vote` - Cast vote
- `DELETE /api/posts/{id}/vote` - Remove vote
- `GET /api/posts/{id}/votes` - Get post votes
- `POST /api/posts/{id}/vote` - Cast vote (protected)
- `DELETE /api/posts/{id}/vote` - Remove vote (protected)
- `GET /api/posts/{id}/vote` - Get current user's vote (protected)
- `GET /api/posts/{id}/votes` - Get all votes for post
#### Users
- `GET /api/users` - List all users (protected)
- `POST /api/users` - Create new user (protected)
- `GET /api/users/{id}` - Get specific user
- `GET /api/users/{id}/posts` - Get user's posts (protected)
## CLI Commands
@@ -400,14 +419,10 @@ This will regenerate the swagger documentation and update the `docs/swagger.json
## Roadmap
- [ ] migrate cli to urfave/cli
- [x] migrate cli to urfave/cli
- [ ] add a ML powered nsfw link detection
- [ ] add right management within the app
- [ ] add an admin backoffice to manage rights, users, content and settings
- [ ] add a way to run read-only communities
- [ ] migrate raw CSS to UnoCSS
- [ ] kubernetes deployment
- [ ] store configuration in the database
## License

View File

@@ -1,14 +1,24 @@
package main
import (
"context"
"errors"
"flag"
"fmt"
"io"
"os"
"sync"
"goyco/cmd/goyco/commands"
"goyco/internal/config"
"github.com/joho/godotenv"
"github.com/urfave/cli/v3"
)
var (
helpPrinterOnce sync.Once
defaultHelpPrinter func(io.Writer, string, interface{})
)
func loadDotEnv() {
@@ -55,3 +65,125 @@ func printRunUsage() {
fmt.Fprintln(os.Stderr, "Usage: goyco run")
fmt.Fprintln(os.Stderr, "\nStart the web application in foreground.")
}
func buildRootCommand(cfg *config.Config) *cli.Command {
helpPrinterOnce.Do(func() {
defaultHelpPrinter = cli.HelpPrinter
})
cli.HelpPrinter = func(w io.Writer, templ string, data interface{}) {
if cmd, ok := data.(*cli.Command); ok && cmd.Root() == cmd {
printRootUsage()
return
}
defaultHelpPrinter(w, templ, data)
}
root := &cli.Command{
Name: "goyco",
Usage: "Y Combinator-style news aggregation platform API",
UsageText: "goyco <command> [<args>]",
HideVersion: true,
Flags: []cli.Flag{
&cli.BoolFlag{
Name: "json",
Usage: "output results in JSON format",
Value: cfg.CLI.JSONOutputDefault,
},
},
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
commands.SetJSONOutput(cmd.Bool("json"))
return ctx, nil
},
After: func(ctx context.Context, cmd *cli.Command) error {
cli.HelpPrinter = defaultHelpPrinter
return nil
},
Action: func(_ context.Context, cmd *cli.Command) error {
if cmd.NArg() == 0 {
printRootUsage()
return nil
}
printRootUsage()
return fmt.Errorf("unknown command: %s", cmd.Args().First())
},
Commands: []*cli.Command{
{
Name: "run",
Usage: "start the web application in foreground",
SkipFlagParsing: true,
Action: func(_ context.Context, cmd *cli.Command) error {
return handleRunCommand(cfg, cmd.Args().Slice())
},
},
{
Name: "start",
Usage: "start the web application in background",
SkipFlagParsing: true,
Action: func(_ context.Context, cmd *cli.Command) error {
return commands.HandleStartCommand(cfg, cmd.Args().Slice())
},
},
{
Name: "stop",
Usage: "stop the daemon",
SkipFlagParsing: true,
Action: func(_ context.Context, cmd *cli.Command) error {
return commands.HandleStopCommand(cfg, cmd.Args().Slice())
},
},
{
Name: "status",
Usage: "check if the daemon is running",
SkipFlagParsing: true,
Action: func(_ context.Context, cmd *cli.Command) error {
return commands.HandleStatusCommand(cfg, cmd.Name, cmd.Args().Slice())
},
},
{
Name: "migrate",
Aliases: []string{"migrations"},
Usage: "apply database migrations",
SkipFlagParsing: true,
Action: func(_ context.Context, cmd *cli.Command) error {
return commands.HandleMigrateCommand(cfg, cmd.Name, cmd.Args().Slice())
},
},
{
Name: "user",
Usage: "manage users (create, update, delete, lock, list)",
SkipFlagParsing: true,
Action: func(_ context.Context, cmd *cli.Command) error {
return commands.HandleUserCommand(cfg, cmd.Name, cmd.Args().Slice())
},
},
{
Name: "post",
Usage: "manage posts (delete, list, search)",
SkipFlagParsing: true,
Action: func(_ context.Context, cmd *cli.Command) error {
return commands.HandlePostCommand(cfg, cmd.Name, cmd.Args().Slice())
},
},
{
Name: "prune",
Usage: "hard delete users and posts (posts, all)",
SkipFlagParsing: true,
Action: func(_ context.Context, cmd *cli.Command) error {
return commands.HandlePruneCommand(cfg, cmd.Name, cmd.Args().Slice())
},
},
{
Name: "seed",
Usage: "seed database with random data",
SkipFlagParsing: true,
Action: func(_ context.Context, cmd *cli.Command) error {
return commands.HandleSeedCommand(cfg, cmd.Name, cmd.Args().Slice())
},
},
},
Writer: os.Stdout,
ErrWriter: os.Stderr,
}
return root
}

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"errors"
"flag"
"os"
@@ -131,25 +132,26 @@ func TestPrintRunUsage(t *testing.T) {
printRunUsage()
}
func TestDispatchCommand(t *testing.T) {
func TestRootCommandDispatch(t *testing.T) {
t.Run("unknown command", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := dispatchCommand(cfg, "unknown", []string{})
cmd := buildRootCommand(cfg)
err := cmd.Run(context.Background(), []string{"goyco", "unknown"})
if err == nil {
t.Error("expected error for unknown command")
}
expectedErr := "unknown command: unknown"
if err.Error() != expectedErr {
if err != nil && err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("help command", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := dispatchCommand(cfg, "help", []string{})
cmd := buildRootCommand(cfg)
err := cmd.Run(context.Background(), []string{"goyco", "help"})
if err != nil {
t.Errorf("unexpected error for help command: %v", err)
@@ -158,7 +160,8 @@ func TestDispatchCommand(t *testing.T) {
t.Run("h command", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := dispatchCommand(cfg, "-h", []string{})
cmd := buildRootCommand(cfg)
err := cmd.Run(context.Background(), []string{"goyco", "-h"})
if err != nil {
t.Errorf("unexpected error for -h command: %v", err)
@@ -167,7 +170,8 @@ func TestDispatchCommand(t *testing.T) {
t.Run("--help command", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := dispatchCommand(cfg, "--help", []string{})
cmd := buildRootCommand(cfg)
err := cmd.Run(context.Background(), []string{"goyco", "--help"})
if err != nil {
t.Errorf("unexpected error for --help command: %v", err)
@@ -179,7 +183,8 @@ func TestDispatchCommand(t *testing.T) {
useInMemoryCommandsConnector(t)
err := dispatchCommand(cfg, "post", []string{"list"})
cmd := buildRootCommand(cfg)
err := cmd.Run(context.Background(), []string{"goyco", "post", "list"})
if err != nil {
t.Errorf("unexpected error for post list: %v", err)

View File

@@ -1,527 +0,0 @@
package commands
import (
"context"
"fmt"
"math/rand"
"runtime"
"sync"
"time"
"goyco/internal/database"
"goyco/internal/repositories"
)
type ParallelProcessor struct {
maxWorkers int
timeout time.Duration
passwordHash string
randSource *rand.Rand
randMu sync.Mutex
}
func NewParallelProcessor() *ParallelProcessor {
maxWorkers := max(min(runtime.NumCPU(), 8), 2)
seed := time.Now().UnixNano()
return &ParallelProcessor{
maxWorkers: maxWorkers,
timeout: 60 * time.Second,
randSource: rand.New(rand.NewSource(seed)),
}
}
func (p *ParallelProcessor) SetPasswordHash(hash string) {
p.passwordHash = hash
}
type indexedResult[T any] struct {
value T
index int
}
func processInParallel[T any](
ctx context.Context,
maxWorkers int,
count int,
processor func(index int) (T, error),
errorPrefix string,
progress *ProgressIndicator,
) ([]T, error) {
results := make(chan indexedResult[T], count)
errors := make(chan error, count)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i := range count {
wg.Add(1)
go func(index int) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
value, err := processor(index + 1)
if err != nil {
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return
}
results <- indexedResult[T]{value: value, index: index}
}(i)
}
go func() {
wg.Wait()
close(results)
close(errors)
}()
items := make([]T, count)
completed := 0
firstError := make(chan error, 1)
go func() {
for err := range errors {
if err != nil {
select {
case firstError <- err:
default:
}
return
}
}
}()
for completed < count {
select {
case result, ok := <-results:
if !ok {
return items, nil
}
items[result.index] = result.value
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return nil, err
case <-ctx.Done():
return nil, fmt.Errorf("timeout: %w", ctx.Err())
}
}
return items, nil
}
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.User, error) {
return p.createSingleUser(userRepo, index)
},
"create user",
progress,
)
}
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.Post, error) {
return p.createSinglePost(postRepo, authorID, index)
},
"create post",
progress,
)
}
func processItemsInParallel[T any, R any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) (R, error),
errorPrefix string,
aggregator func(accumulator R, value R) R,
initialValue R,
progress *ProgressIndicator,
) (R, error) {
count := len(items)
results := make(chan indexedResult[R], count)
errors := make(chan error, count)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i, item := range items {
wg.Add(1)
go func(index int, item T) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
value, err := processor(index, item)
if err != nil {
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return
}
results <- indexedResult[R]{value: value, index: index}
}(i, item)
}
go func() {
wg.Wait()
close(results)
close(errors)
}()
accumulator := initialValue
completed := 0
firstError := make(chan error, 1)
go func() {
for err := range errors {
if err != nil {
select {
case firstError <- err:
default:
}
return
}
}
}()
for completed < count {
select {
case result, ok := <-results:
if !ok {
return accumulator, nil
}
accumulator = aggregator(accumulator, result.value)
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return initialValue, err
case <-ctx.Done():
return initialValue, fmt.Errorf("timeout: %w", ctx.Err())
}
}
return accumulator, nil
}
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
return processItemsInParallel(ctx, p.maxWorkers, posts,
func(index int, post database.Post) (int, error) {
return p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
},
"create votes for post",
func(acc, val int) int { return acc + val },
0,
progress,
)
}
func processItemsInParallelNoResult[T any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) error,
errorFormatter func(index int, item T, err error) error,
progress *ProgressIndicator,
) error {
count := len(items)
errors := make(chan error, count)
completions := make(chan struct{}, count)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i, item := range items {
wg.Add(1)
go func(index int, item T) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
err := processor(index, item)
if err != nil {
if errorFormatter != nil {
errors <- errorFormatter(index, item, err)
} else {
errors <- fmt.Errorf("process item %d: %w", index+1, err)
}
return
}
completions <- struct{}{}
}(i, item)
}
go func() {
wg.Wait()
close(errors)
close(completions)
}()
completed := 0
firstError := make(chan error, 1)
go func() {
for err := range errors {
if err != nil {
select {
case firstError <- err:
default:
}
return
}
}
}()
for completed < count {
select {
case _, ok := <-completions:
if !ok {
return nil
}
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return err
case <-ctx.Done():
return fmt.Errorf("timeout: %w", ctx.Err())
}
}
return nil
}
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
return processItemsInParallelNoResult(ctx, p.maxWorkers, posts,
func(index int, post database.Post) error {
return p.updateSinglePostScore(postRepo, voteRepo, post)
},
func(index int, post database.Post, err error) error {
return fmt.Errorf("update post %d scores: %w", post.ID, err)
},
progress,
)
}
func (p *ParallelProcessor) generateRandomIdentifier() string {
const length = 12
const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
identifier := make([]byte, length)
p.randMu.Lock()
for i := range identifier {
identifier[i] = chars[p.randSource.Intn(len(chars))]
}
p.randMu.Unlock()
return string(identifier)
}
func (p *ParallelProcessor) createSingleUser(userRepo repositories.UserRepository, index int) (database.User, error) {
randomID := p.generateRandomIdentifier()
username := fmt.Sprintf("user_%s", randomID)
email := fmt.Sprintf("user_%s@goyco.local", randomID)
const maxRetries = 10
for range maxRetries {
user := &database.User{
Username: username,
Email: email,
Password: p.passwordHash,
EmailVerified: true,
}
if err := userRepo.Create(user); err != nil {
randomID = p.generateRandomIdentifier()
username = fmt.Sprintf("user_%s", randomID)
email = fmt.Sprintf("user_%s@goyco.local", randomID)
continue
}
return *user, nil
}
return database.User{}, fmt.Errorf("failed to create user after %d attempts", maxRetries)
}
func (p *ParallelProcessor) createSinglePost(postRepo repositories.PostRepository, authorID uint, index int) (database.Post, error) {
sampleTitles := []string{
"Amazing JavaScript Framework",
"Python Best Practices",
"Go Performance Tips",
"Database Optimization",
"Web Security Guide",
"Machine Learning Basics",
"Cloud Architecture",
"DevOps Automation",
"API Design Patterns",
"Frontend Optimization",
"Backend Scaling",
"Container Orchestration",
"Microservices Architecture",
"Testing Strategies",
"Code Review Process",
"Version Control Best Practices",
"Continuous Integration",
"Monitoring and Alerting",
"Error Handling Patterns",
"Data Structures Explained",
}
sampleDomains := []string{
"example.com",
"techblog.org",
"devguide.net",
"programming.io",
"codeexamples.com",
"tutorialhub.org",
"bestpractices.dev",
"learnprogramming.net",
"codingtips.org",
"softwareengineering.com",
}
title := sampleTitles[index%len(sampleTitles)]
if index >= len(sampleTitles) {
title = fmt.Sprintf("%s - Part %d", title, (index/len(sampleTitles))+1)
}
domain := sampleDomains[index%len(sampleDomains)]
randomID := p.generateRandomIdentifier()
path := fmt.Sprintf("/article/%s", randomID)
url := fmt.Sprintf("https://%s%s", domain, path)
content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", index, title)
const maxRetries = 10
for range maxRetries {
post := &database.Post{
Title: title,
URL: url,
Content: content,
AuthorID: &authorID,
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
if err := postRepo.Create(post); err != nil {
randomID = p.generateRandomIdentifier()
path = fmt.Sprintf("/article/%s", randomID)
url = fmt.Sprintf("https://%s%s", domain, path)
continue
}
return *post, nil
}
return database.Post{}, fmt.Errorf("failed to create post after %d attempts", maxRetries)
}
func (p *ParallelProcessor) createVotesForPost(voteRepo repositories.VoteRepository, users []database.User, post database.Post, avgVotesPerPost int) (int, error) {
p.randMu.Lock()
numVotes := p.randSource.Intn(avgVotesPerPost*2 + 1)
p.randMu.Unlock()
if numVotes == 0 && avgVotesPerPost > 0 {
p.randMu.Lock()
if p.randSource.Intn(5) > 0 {
numVotes = 1
}
p.randMu.Unlock()
}
totalVotes := 0
usedUsers := make(map[uint]bool)
for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
p.randMu.Lock()
userIdx := p.randSource.Intn(len(users))
p.randMu.Unlock()
user := users[userIdx]
if usedUsers[user.ID] {
continue
}
usedUsers[user.ID] = true
p.randMu.Lock()
voteTypeInt := p.randSource.Intn(10)
p.randMu.Unlock()
var voteType database.VoteType
if voteTypeInt < 7 {
voteType = database.VoteUp
} else {
voteType = database.VoteDown
}
vote := &database.Vote{
UserID: &user.ID,
PostID: post.ID,
Type: voteType,
}
if err := voteRepo.CreateOrUpdate(vote); err != nil {
return totalVotes, fmt.Errorf("create or update vote: %w", err)
}
totalVotes++
}
return totalVotes, nil
}
func (p *ParallelProcessor) updateSinglePostScore(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, post database.Post) error {
upVotes, downVotes, err := getVoteCounts(voteRepo, post.ID)
if err != nil {
return fmt.Errorf("get vote counts: %w", err)
}
post.UpVotes = upVotes
post.DownVotes = downVotes
post.Score = upVotes - downVotes
if err := postRepo.Update(&post); err != nil {
return fmt.Errorf("update post: %w", err)
}
return nil
}

View File

@@ -1,145 +0,0 @@
package commands_test
import (
"errors"
"sync"
"testing"
"goyco/cmd/goyco/commands"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/testutils"
"golang.org/x/crypto/bcrypt"
)
func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
const successCount = 4
tests := []struct {
name string
count int
repoFactory func() repositories.UserRepository
progress *commands.ProgressIndicator
validate func(t *testing.T, got []database.User)
wantErr bool
}{
{
name: "creates users with required fields",
count: successCount,
repoFactory: func() repositories.UserRepository {
base := testutils.NewMockUserRepository()
return newFakeUserRepo(base, 0, nil)
},
progress: nil,
validate: func(t *testing.T, got []database.User) {
t.Helper()
if len(got) != successCount {
t.Fatalf("expected %d users, got %d", successCount, len(got))
}
usernames := make(map[string]bool)
for i, user := range got {
if user.Username == "" {
t.Errorf("user %d expected non-empty username", i)
}
if len(user.Username) < 6 || user.Username[:5] != "user_" {
t.Errorf("user %d username should start with 'user_', got %q", i, user.Username)
}
if usernames[user.Username] {
t.Errorf("user %d duplicate username: %q", i, user.Username)
}
usernames[user.Username] = true
if user.Email == "" {
t.Errorf("user %d expected non-empty email", i)
}
if len(user.Email) < 20 || user.Email[:5] != "user_" || user.Email[len(user.Email)-12:] != "@goyco.local" {
t.Errorf("user %d email should match pattern 'user_*@goyco.local', got %q", i, user.Email)
}
if !user.EmailVerified {
t.Errorf("user %d expected EmailVerified to be true", i)
}
if user.ID == 0 {
t.Errorf("user %d expected non-zero ID", i)
}
if user.Password == "" {
t.Errorf("user %d expected hashed password to be populated", i)
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte("password123")); err != nil {
t.Errorf("user %d password not hashed correctly: %v", i, err)
}
if user.CreatedAt.IsZero() {
t.Errorf("user %d expected CreatedAt to be set", i)
}
if user.UpdatedAt.IsZero() {
t.Errorf("user %d expected UpdatedAt to be set", i)
}
}
},
},
{
name: "returns error when repository create fails",
count: 3,
repoFactory: func() repositories.UserRepository {
base := testutils.NewMockUserRepository()
return newFakeUserRepo(base, 1, errors.New("create failure"))
},
progress: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
repo := tt.repoFactory()
p := commands.NewParallelProcessor()
passwordHash, err := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("failed to generate password hash: %v", err)
}
p.SetPasswordHash(string(passwordHash))
got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress)
if gotErr != nil {
if !tt.wantErr {
t.Errorf("CreateUsersInParallel() failed: %v", gotErr)
}
if got != nil {
t.Error("expected nil result when error occurs")
}
return
}
if tt.wantErr {
t.Fatal("CreateUsersInParallel() succeeded unexpectedly")
}
if tt.validate != nil {
tt.validate(t, got)
}
})
}
}
type fakeUserRepo struct {
repositories.UserRepository
mu sync.Mutex
failAt int
err error
calls int
}
func newFakeUserRepo(base repositories.UserRepository, failAt int, err error) *fakeUserRepo {
return &fakeUserRepo{
UserRepository: base,
failAt: failAt,
err: err,
}
}
func (r *fakeUserRepo) Create(user *database.User) error {
r.mu.Lock()
defer r.mu.Unlock()
r.calls++
if r.failAt > 0 && r.calls >= r.failAt {
return r.err
}
return r.UserRepository.Create(user)
}

View File

@@ -56,16 +56,16 @@ func newProgressIndicatorWithClock(total int, description string, c clock) *Prog
func (p *ProgressIndicator) Update(current int) {
p.mu.Lock()
defer p.mu.Unlock()
p.current = current
now := p.clock.Now()
if now.Sub(p.lastUpdate) < 100*time.Millisecond {
p.mu.Unlock()
return
}
p.lastUpdate = now
p.mu.Unlock()
p.display()
}

View File

@@ -44,15 +44,14 @@ func captureOutput(fn func()) string {
r, w, _ := os.Pipe()
os.Stdout = w
defer func() {
fn()
_ = w.Close()
os.Stdout = old
}()
fn()
var buf bytes.Buffer
_, _ = io.Copy(&buf, r)
_ = r.Close()
return buf.String()
}

View File

@@ -6,29 +6,17 @@ import (
"fmt"
"math/rand"
"os"
"sync"
"strings"
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/repositories"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
seedRandSource *rand.Rand
seedRandOnce sync.Once
)
func initSeedRand() {
seedRandOnce.Do(func() {
seed := time.Now().UnixNano()
seedRandSource = rand.New(rand.NewSource(seed))
})
}
func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printSeedUsage)
if err := parseCommand(fs, args, name); err != nil {
@@ -39,11 +27,13 @@ func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
}
return withDatabase(cfg, func(db *gorm.DB) error {
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
return db.Transaction(func(tx *gorm.DB) error {
userRepo := repositories.NewUserRepository(db).WithTx(tx)
postRepo := repositories.NewPostRepository(db).WithTx(tx)
voteRepo := repositories.NewVoteRepository(db).WithTx(tx)
return runSeedCommand(userRepo, postRepo, voteRepo, fs.Args())
})
})
}
func runSeedCommand(userRepo repositories.UserRepository, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, args []string) error {
@@ -72,45 +62,37 @@ func printSeedUsage() {
fmt.Fprintln(os.Stderr, " --votes-per-post: average votes per post (default: 15)")
}
func clampFlagValue(value *int, min int, name string) {
if *value < min {
if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --%s value %d is too low, clamping to %d\n", name, *value, min)
}
*value = min
}
}
func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, args []string) error {
fs := flag.NewFlagSet("seed database", flag.ContinueOnError)
numPosts := fs.Int("posts", 40, "number of posts to create")
numUsers := fs.Int("users", 5, "number of additional users to create")
votesPerPost := fs.Int("votes-per-post", 15, "average votes per post")
fs.SetOutput(os.Stderr)
fs.Usage = func() {
fmt.Fprintln(os.Stderr, "Usage: goyco seed database [--posts <n>] [--users <n>] [--votes-per-post <n>]")
fmt.Fprintln(os.Stderr, "\nOptions:")
fs.PrintDefaults()
}
if err := fs.Parse(args); err != nil {
if errors.Is(err, flag.ErrHelp) {
return nil
}
return err
}
originalUsers := *numUsers
originalPosts := *numPosts
originalVotesPerPost := *votesPerPost
if *numUsers < 0 {
if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --users value %d is negative, clamping to 0\n", *numUsers)
}
*numUsers = 0
}
if *numPosts <= 0 {
if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --posts value %d is too low, clamping to 1\n", *numPosts)
}
*numPosts = 1
}
if *votesPerPost < 0 {
if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --votes-per-post value %d is negative, clamping to 0\n", *votesPerPost)
}
*votesPerPost = 0
}
if !IsJSONOutput() && (originalUsers != *numUsers || originalPosts != *numPosts || originalVotesPerPost != *votesPerPost) {
fmt.Fprintf(os.Stderr, "Using clamped values: --users=%d --posts=%d --votes-per-post=%d\n", *numUsers, *numPosts, *votesPerPost)
}
clampFlagValue(numUsers, 0, "users")
clampFlagValue(numPosts, 1, "posts")
clampFlagValue(votesPerPost, 0, "votes-per-post")
if !IsJSONOutput() {
fmt.Println("Starting database seeding...")
@@ -129,71 +111,35 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
return fmt.Errorf("precompute user password hash: %w", err)
}
spinner := NewSpinner("Creating seed user")
if !IsJSONOutput() {
spinner.Spin()
}
seedUser, err := ensureSeedUser(userRepo, string(seedPasswordHash))
if err != nil {
if !IsJSONOutput() {
spinner.Complete()
}
return fmt.Errorf("ensure seed user: %w", err)
}
if !IsJSONOutput() {
spinner.Complete()
fmt.Printf("Seed user ready: ID=%d Username=%s\n", seedUser.ID, seedUser.Username)
}
processor := NewParallelProcessor()
processor.SetPasswordHash(string(userPasswordHash))
generator := newSeedGenerator(string(userPasswordHash))
allUsers := []database.User{*seedUser}
var progress *ProgressIndicator
if !IsJSONOutput() && *numUsers > 0 {
progress = NewProgressIndicator(*numUsers, "Creating users (parallel)")
}
users, err := processor.CreateUsersInParallel(userRepo, *numUsers, progress)
users, err := createUsers(generator, userRepo, *numUsers, "Creating users")
if err != nil {
return fmt.Errorf("create random users: %w", err)
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
return err
}
allUsers = append(allUsers, users...)
allUsers := append([]database.User{*seedUser}, users...)
if !IsJSONOutput() && *numPosts > 0 {
progress = NewProgressIndicator(*numPosts, "Creating posts (parallel)")
}
posts, err := processor.CreatePostsInParallel(postRepo, seedUser.ID, *numPosts, progress)
posts, err := createPosts(generator, postRepo, seedUser.ID, *numPosts, "Creating posts")
if err != nil {
return fmt.Errorf("create random posts: %w", err)
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
return err
}
if !IsJSONOutput() && len(posts) > 0 {
progress = NewProgressIndicator(len(posts), "Creating votes (parallel)")
}
votes, err := processor.CreateVotesInParallel(voteRepo, allUsers, posts, *votesPerPost, progress)
votes, err := createVotes(generator, voteRepo, allUsers, posts, *votesPerPost, "Creating votes")
if err != nil {
return fmt.Errorf("create random votes: %w", err)
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
return err
}
if !IsJSONOutput() && len(posts) > 0 {
progress = NewProgressIndicator(len(posts), "Updating scores (parallel)")
}
err = processor.UpdatePostScoresInParallel(postRepo, voteRepo, posts, progress)
if err != nil {
return fmt.Errorf("update post scores: %w", err)
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
if err := updateScores(generator, postRepo, voteRepo, posts, "Updating scores"); err != nil {
return err
}
if err := validateSeedConsistency(voteRepo, allUsers, posts); err != nil {
@@ -225,11 +171,15 @@ const (
)
func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) {
if user, err := userRepo.GetByUsername(seedUsername); err == nil {
user, err := userRepo.GetByUsername(seedUsername)
if err == nil {
return user, nil
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("failed to check if seed user exists: %w", err)
}
user := &database.User{
user = &database.User{
Username: seedUsername,
Email: seedEmail,
Password: passwordHash,
@@ -243,10 +193,6 @@ func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (
return user, nil
}
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
return voteRepo.GetVoteCountsByPostID(postID)
}
func validateSeedConsistency(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
userIDSet := make(map[uint]struct{}, len(users))
for _, user := range users {
@@ -259,8 +205,11 @@ func validateSeedConsistency(voteRepo repositories.VoteRepository, users []datab
}
for _, post := range posts {
if err := validatePost(post, userIDSet); err != nil {
return err
if post.AuthorID == nil {
return fmt.Errorf("post %d has no author ID", post.ID)
}
if _, exists := userIDSet[*post.AuthorID]; !exists {
return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID)
}
votes, err := voteRepo.GetByPostID(post.ID)
@@ -268,46 +217,293 @@ func validateSeedConsistency(voteRepo repositories.VoteRepository, users []datab
return fmt.Errorf("failed to retrieve votes for post %d: %w", post.ID, err)
}
if err := validateVotesForPost(post.ID, votes, userIDSet, postIDSet); err != nil {
return err
}
}
return nil
}
func validatePost(post database.Post, userIDSet map[uint]struct{}) error {
if post.AuthorID == nil {
return fmt.Errorf("post %d has no author ID", post.ID)
}
if _, exists := userIDSet[*post.AuthorID]; !exists {
return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID)
}
return nil
}
func validateVotesForPost(postID uint, votes []database.Vote, userIDSet map[uint]struct{}, postIDSet map[uint]struct{}) error {
for _, vote := range votes {
if vote.PostID != postID {
return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, postID)
if vote.PostID != post.ID {
return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, post.ID)
}
if _, exists := postIDSet[vote.PostID]; !exists {
return fmt.Errorf("vote %d references non-existent post ID %d", vote.ID, vote.PostID)
}
if vote.UserID != nil {
if _, exists := userIDSet[*vote.UserID]; !exists {
return fmt.Errorf("vote %d references non-existent user ID %d", vote.ID, *vote.UserID)
}
}
if vote.Type != database.VoteUp && vote.Type != database.VoteDown {
return fmt.Errorf("vote %d has invalid type %q", vote.ID, vote.Type)
}
}
}
return nil
}
type seedGenerator struct {
passwordHash string
randSource *rand.Rand
}
func newSeedGenerator(passwordHash string) *seedGenerator {
seed := time.Now().UnixNano()
return &seedGenerator{
passwordHash: passwordHash,
randSource: rand.New(rand.NewSource(seed)),
}
}
func isRetryableError(err error, keywords ...string) bool {
if err == nil {
return false
}
errMsg := strings.ToLower(err.Error())
if errors.Is(err, gorm.ErrDuplicatedKey) {
for _, keyword := range keywords {
if strings.Contains(errMsg, keyword) {
return true
}
}
return false
}
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
constraintLower := strings.ToLower(pqErr.Constraint)
errMsgLower := strings.ToLower(pqErr.Message)
for _, keyword := range keywords {
if strings.Contains(constraintLower, keyword) || strings.Contains(errMsgLower, keyword) {
return true
}
}
return false
}
if strings.Contains(errMsg, "duplicate") {
for _, keyword := range keywords {
if strings.Contains(errMsg, keyword) {
return true
}
}
}
return false
}
func createUsers(g *seedGenerator, userRepo repositories.UserRepository, count int, desc string) ([]database.User, error) {
if count == 0 {
return nil, nil
}
progress := maybeProgress(count, desc)
users := make([]database.User, 0, count)
for i := 0; i < count; i++ {
user, err := g.createSingleUser(userRepo, i+1)
if err != nil {
return nil, fmt.Errorf("create random user: %w", err)
}
users = append(users, user)
if progress != nil {
progress.Increment()
}
}
if progress != nil {
progress.Complete()
}
return users, nil
}
func createPosts(g *seedGenerator, postRepo repositories.PostRepository, authorID uint, count int, desc string) ([]database.Post, error) {
if count == 0 {
return nil, nil
}
progress := maybeProgress(count, desc)
posts := make([]database.Post, 0, count)
for i := 0; i < count; i++ {
post, err := g.createSinglePost(postRepo, authorID, i+1)
if err != nil {
return nil, fmt.Errorf("create random post: %w", err)
}
posts = append(posts, post)
if progress != nil {
progress.Increment()
}
}
if progress != nil {
progress.Complete()
}
return posts, nil
}
func createVotes(g *seedGenerator, voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, desc string) (int, error) {
if len(posts) == 0 {
return 0, nil
}
progress := maybeProgress(len(posts), desc)
votes := 0
for _, post := range posts {
count, err := g.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
if err != nil {
return 0, fmt.Errorf("create random votes for post %d: %w", post.ID, err)
}
votes += count
if progress != nil {
progress.Increment()
}
}
if progress != nil {
progress.Complete()
}
return votes, nil
}
func updateScores(g *seedGenerator, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, desc string) error {
if len(posts) == 0 {
return nil
}
progress := maybeProgress(len(posts), desc)
for _, post := range posts {
if err := g.updateSinglePostScore(postRepo, voteRepo, post); err != nil {
return fmt.Errorf("update post scores: %w", err)
}
if progress != nil {
progress.Increment()
}
}
if progress != nil {
progress.Complete()
}
return nil
}
func maybeProgress(count int, desc string) *ProgressIndicator {
if !IsJSONOutput() && count > 0 {
return NewProgressIndicator(count, desc)
}
return nil
}
func (g *seedGenerator) generateRandomIdentifier() string {
const length = 12
const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
identifier := make([]byte, length)
for i := range identifier {
identifier[i] = chars[g.randSource.Intn(len(chars))]
}
return string(identifier)
}
func (g *seedGenerator) createSingleUser(userRepo repositories.UserRepository, index int) (database.User, error) {
const maxRetries = 10
var lastErr error
for attempt := range maxRetries {
randomID := g.generateRandomIdentifier()
user := &database.User{
Username: fmt.Sprintf("user_%s", randomID),
Email: fmt.Sprintf("user_%s@goyco.local", randomID),
Password: g.passwordHash,
EmailVerified: true,
}
if err := userRepo.Create(user); err != nil {
lastErr = err
if !isRetryableError(err, "username", "email", "users_username_key", "users_email_key") {
return database.User{}, fmt.Errorf("failed to create user (attempt %d/%d): %w", attempt+1, maxRetries, err)
}
continue
}
return *user, nil
}
return database.User{}, fmt.Errorf("failed to create user after %d attempts: %w", maxRetries, lastErr)
}
var (
sampleTitles = []string{"Amazing JavaScript Framework", "Python Best Practices", "Go Performance Tips", "Database Optimization", "Web Security Guide", "Machine Learning Basics", "Cloud Architecture", "DevOps Automation", "API Design Patterns", "Frontend Optimization", "Backend Scaling", "Container Orchestration", "Microservices Architecture", "Testing Strategies", "Code Review Process", "Version Control Best Practices", "Continuous Integration", "Monitoring and Alerting", "Error Handling Patterns", "Data Structures Explained"}
sampleDomains = []string{"example.com", "techblog.org", "devguide.net", "programming.io", "codeexamples.com", "tutorialhub.org", "bestpractices.dev", "learnprogramming.net", "codingtips.org", "softwareengineering.com"}
)
func (g *seedGenerator) createSinglePost(postRepo repositories.PostRepository, authorID uint, index int) (database.Post, error) {
title := sampleTitles[index%len(sampleTitles)]
if index >= len(sampleTitles) {
title = fmt.Sprintf("%s - Part %d", title, (index/len(sampleTitles))+1)
}
domain := sampleDomains[index%len(sampleDomains)]
content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", index, title)
const maxRetries = 10
var lastErr error
for attempt := range maxRetries {
randomID := g.generateRandomIdentifier()
post := &database.Post{
Title: title,
URL: fmt.Sprintf("https://%s/article/%s", domain, randomID),
Content: content,
AuthorID: &authorID,
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
if err := postRepo.Create(post); err != nil {
lastErr = err
if !isRetryableError(err, "url", "posts_url_key") {
return database.Post{}, fmt.Errorf("failed to create post (attempt %d/%d): %w", attempt+1, maxRetries, err)
}
continue
}
return *post, nil
}
return database.Post{}, fmt.Errorf("failed to create post after %d attempts: %w", maxRetries, lastErr)
}
func (g *seedGenerator) createVotesForPost(voteRepo repositories.VoteRepository, users []database.User, post database.Post, avgVotesPerPost int) (int, error) {
numVotes := g.randSource.Intn(avgVotesPerPost*2 + 1)
if numVotes == 0 && avgVotesPerPost > 0 {
if g.randSource.Intn(5) > 0 {
numVotes = 1
}
}
totalVotes := 0
usedUsers := make(map[uint]bool)
for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
userIdx := g.randSource.Intn(len(users))
user := users[userIdx]
if usedUsers[user.ID] {
continue
}
usedUsers[user.ID] = true
voteTypeInt := g.randSource.Intn(10)
var voteType database.VoteType
if voteTypeInt < 7 {
voteType = database.VoteUp
} else {
voteType = database.VoteDown
}
vote := &database.Vote{
UserID: &user.ID,
PostID: post.ID,
Type: voteType,
}
if err := voteRepo.CreateOrUpdate(vote); err != nil {
return totalVotes, fmt.Errorf("create or update vote: %w", err)
}
totalVotes++
}
return totalVotes, nil
}
func (g *seedGenerator) updateSinglePostScore(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, post database.Post) error {
upVotes, downVotes, err := voteRepo.GetVoteCountsByPostID(post.ID)
if err != nil {
return fmt.Errorf("get vote counts: %w", err)
}
post.UpVotes = upVotes
post.DownVotes = downVotes
post.Score = upVotes - downVotes
return postRepo.Update(&post)
}

View File

@@ -2,8 +2,11 @@ package commands
import (
"fmt"
"math/rand"
"strings"
"sync"
"testing"
"time"
"goyco/internal/database"
"goyco/internal/repositories"
@@ -13,6 +16,20 @@ import (
"gorm.io/gorm"
)
var (
seedRandSource *rand.Rand
seedRandOnce sync.Once
)
const testPasswordHash = "test_password_hash"
func initSeedRand() {
seedRandOnce.Do(func() {
seed := time.Now().UnixNano()
seedRandSource = rand.New(rand.NewSource(seed))
})
}
func TestSeedCommand(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
@@ -46,11 +63,11 @@ func TestSeedCommand(t *testing.T) {
seedUserCount := 0
var seedUser *database.User
regularUserCount := 0
for i := range users {
if users[i].Username == "seed_admin" {
for idx := range users {
if users[idx].Username == seedUsername {
seedUserCount++
seedUser = &users[i]
} else if strings.HasPrefix(users[i].Username, "user_") {
seedUser = &users[idx]
} else if strings.HasPrefix(users[idx].Username, "user_") {
regularUserCount++
}
}
@@ -63,12 +80,12 @@ func TestSeedCommand(t *testing.T) {
t.Fatal("Expected seed user to be created")
}
if seedUser.Username != "seed_admin" {
t.Errorf("Expected username to be 'seed_admin', got '%s'", seedUser.Username)
if seedUser.Username != seedUsername {
t.Errorf("Expected username to be %q, got '%s'", seedUsername, seedUser.Username)
}
if seedUser.Email != "seed_admin@goyco.local" {
t.Errorf("Expected email to be 'seed_admin@goyco.local', got '%s'", seedUser.Email)
if seedUser.Email != seedEmail {
t.Errorf("Expected email to be %q, got '%s'", seedEmail, seedUser.Email)
}
if !seedUser.EmailVerified {
@@ -88,20 +105,20 @@ func TestSeedCommand(t *testing.T) {
t.Errorf("Expected 5 posts, got %d", len(posts))
}
for i, post := range posts {
for idx, post := range posts {
if post.Title == "" {
t.Errorf("Post %d has empty title", i)
t.Errorf("Post %d has empty title", idx)
}
if post.URL == "" {
t.Errorf("Post %d has empty URL", i)
t.Errorf("Post %d has empty URL", idx)
}
if post.AuthorID == nil || *post.AuthorID != seedUser.ID {
t.Errorf("Post %d has wrong author ID: expected %d, got %v", i, seedUser.ID, post.AuthorID)
t.Errorf("Post %d has wrong author ID: expected %d, got %v", idx, seedUser.ID, post.AuthorID)
}
expectedScore := post.UpVotes - post.DownVotes
if post.Score != expectedScore {
t.Errorf("Post %d has incorrect score: expected %d, got %d", i, expectedScore, post.Score)
t.Errorf("Post %d has incorrect score: expected %d, got %d", idx, expectedScore, post.Score)
}
}
@@ -133,11 +150,12 @@ func TestSeedCommand(t *testing.T) {
}
func TestGenerateRandomPath(t *testing.T) {
const articlePathPrefix = "/article/"
initSeedRand()
pathLength := seedRandSource.Intn(20)
path := "/article/"
path := articlePathPrefix
for i := 0; i < pathLength+5; i++ {
for idx := 0; idx < pathLength+5; idx++ {
randomChar := seedRandSource.Intn(26)
path += string(rune('a' + randomChar))
}
@@ -152,13 +170,14 @@ func TestGenerateRandomPath(t *testing.T) {
initSeedRand()
secondPathLength := seedRandSource.Intn(20)
secondPath := "/article/"
for i := 0; i < secondPathLength+5; i++ {
var secondPath strings.Builder
secondPath.WriteString(articlePathPrefix)
for idx := 0; idx < secondPathLength+5; idx++ {
randomChar := seedRandSource.Intn(26)
secondPath += string(rune('a' + randomChar))
secondPath.WriteString(string(rune('a' + randomChar)))
}
if path == secondPath {
if path == secondPath.String() {
t.Error("Generated paths should be different")
}
}
@@ -271,6 +290,22 @@ func TestSeedDatabaseFlagParsing(t *testing.T) {
t.Errorf("zero votes-per-post should be valid, got error: %v", err)
}
})
t.Run("help flag returns no error", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--help"})
if err != nil {
t.Errorf("help flag should return no error, got: %v", err)
}
})
t.Run("short help flag returns no error", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"-h"})
if err != nil {
t.Errorf("short help flag should return no error, got: %v", err)
}
})
}
func TestSeedCommandIdempotency(t *testing.T) {
@@ -302,7 +337,7 @@ func TestSeedCommandIdempotency(t *testing.T) {
seedUserCount := 0
for _, user := range users {
if user.Username == "seed_admin" {
if user.Username == seedUsername {
seedUserCount++
}
}
@@ -338,10 +373,10 @@ func TestSeedCommandIdempotency(t *testing.T) {
})
t.Run("database remains consistent after multiple runs", func(t *testing.T) {
for i := range 2 {
for idx := range 2 {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "0", "--posts", "1"})
if err != nil {
t.Fatalf("Seed run %d failed: %v", i+1, err)
t.Fatalf("Seed run %d failed: %v", idx+1, err)
}
}
@@ -386,9 +421,9 @@ func TestSeedCommandIdempotency(t *testing.T) {
}
func findSeedUser(users []database.User) *database.User {
for i := range users {
if users[i].Username == "seed_admin" {
return &users[i]
for idx := range users {
if users[idx].Username == seedUsername {
return &users[idx]
}
}
return nil
@@ -488,14 +523,14 @@ func TestEnsureSeedUser(t *testing.T) {
}
userRepo := repositories.NewUserRepository(db)
passwordHash := "test_password_hash"
passwordHash := testPasswordHash
firstUser, err := ensureSeedUser(userRepo, passwordHash)
if err != nil {
t.Fatalf("Failed to create seed user: %v", err)
}
if firstUser.Username != "seed_admin" || firstUser.Email != "seed_admin@goyco.local" || firstUser.Password != passwordHash || !firstUser.EmailVerified {
if firstUser.Username != seedUsername || firstUser.Email != seedEmail || firstUser.Password != passwordHash || !firstUser.EmailVerified {
t.Errorf("Invalid seed user: username=%s, email=%s, password matches=%v, emailVerified=%v",
firstUser.Username, firstUser.Email, firstUser.Password == passwordHash, firstUser.EmailVerified)
}
@@ -509,9 +544,9 @@ func TestEnsureSeedUser(t *testing.T) {
t.Errorf("Expected same user to be reused (ID %d), got different user (ID %d)", firstUser.ID, secondUser.ID)
}
for i := 0; i < 3; i++ {
for idx := range 3 {
if _, err := ensureSeedUser(userRepo, passwordHash); err != nil {
t.Fatalf("Call %d failed: %v", i+1, err)
t.Fatalf("Call %d failed: %v", idx+1, err)
}
}
@@ -522,7 +557,7 @@ func TestEnsureSeedUser(t *testing.T) {
seedUserCount := 0
for _, user := range users {
if user.Username == "seed_admin" {
if user.Username == seedUsername {
seedUserCount++
}
}
@@ -531,3 +566,25 @@ func TestEnsureSeedUser(t *testing.T) {
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
}
}
func TestEnsureSeedUser_HandlesDatabaseErrors(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
passwordHash := testPasswordHash
dbError := fmt.Errorf("database connection failed")
userRepo.SetGetByUsernameError(dbError)
_, err := ensureSeedUser(userRepo, passwordHash)
if err == nil {
t.Fatal("Expected error when GetByUsername returns database error")
}
if !strings.Contains(err.Error(), "failed to check if seed user exists") {
t.Errorf("Expected error message about checking seed user, got: %v", err)
}
if !strings.Contains(err.Error(), dbError.Error()) {
t.Errorf("Expected error to wrap original database error, got: %v", err)
}
}

View File

@@ -1,9 +1,8 @@
package main
import (
"flag"
"context"
"fmt"
"os"
"strings"
"testing"
"unicode/utf8"
@@ -12,6 +11,7 @@ import (
"goyco/internal/config"
"goyco/internal/testutils"
"github.com/urfave/cli/v3"
"gorm.io/gorm"
)
@@ -36,32 +36,15 @@ func FuzzCLIArgs(f *testing.F) {
if len(args) == 0 {
return
}
cmd := buildRootCommand(testutils.NewTestConfig())
for _, sub := range cmd.Commands {
sub.Action = func(context.Context, *cli.Command) error { return nil }
}
fs := flag.NewFlagSet("goyco", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
fs.Usage = printRootUsage
showHelp := fs.Bool("help", false, "show this help message")
err := fs.Parse(args)
err := cmd.Run(context.Background(), append([]string{"goyco"}, args...))
if err != nil {
if !strings.Contains(err.Error(), "flag") && !strings.Contains(err.Error(), "help") {
t.Logf("Unexpected error format from flag parsing: %v", err)
}
}
if *showHelp && err != nil {
return
}
remaining := fs.Args()
if len(remaining) > 0 {
cmdName := remaining[0]
if len(cmdName) == 0 {
t.Fatal("Command name cannot be empty")
}
if !isValidUTF8(cmdName) {
t.Fatal("Command name must be valid UTF-8")
if !strings.Contains(err.Error(), "flag") && !strings.Contains(err.Error(), "unknown command") {
t.Logf("Unexpected error format from command parsing: %v", err)
}
}
})
@@ -96,12 +79,6 @@ func FuzzCommandDispatch(f *testing.F) {
})
defer commands.SetDBConnector(nil)
daemonCommands := map[string]bool{
"start": true,
"stop": true,
"status": true,
}
f.Add("run")
f.Add("help")
f.Add("user")
@@ -121,18 +98,20 @@ func FuzzCommandDispatch(f *testing.F) {
return
}
cmd := buildRootCommand(cfg)
for _, sub := range cmd.Commands {
sub.Action = func(context.Context, *cli.Command) error { return nil }
}
cmdName := parts[0]
args := parts[1:]
if daemonCommands[cmdName] {
return
}
err := dispatchCommand(cfg, cmdName, args)
err := cmd.Run(context.Background(), append([]string{"goyco"}, append([]string{cmdName}, args...)...))
knownCommands := map[string]bool{
"run": true, "user": true, "post": true, "prune": true, "migrate": true,
"migrations": true, "seed": true, "help": true, "-h": true, "--help": true,
"start": true, "stop": true, "status": true,
}
if knownCommands[cmdName] {

View File

@@ -1,5 +1,5 @@
// @title Goyco API
// @version 0.1.0
// @version 0.1.1
// @description Goyco is a Y Combinator-style news aggregation platform API.
// @contact.name Goyco Team
// @contact.email sandro@cazzaniga.fr
@@ -12,8 +12,8 @@
package main
import (
"context"
"errors"
"flag"
"fmt"
"log"
"os"
@@ -55,7 +55,7 @@ func run(args []string) error {
docs.SwaggerInfo.Title = fmt.Sprintf("%s API", cfg.App.Title)
docs.SwaggerInfo.Description = "Y Combinator-style news board API."
docs.SwaggerInfo.Version = version.Version
docs.SwaggerInfo.Version = version.GetVersion()
docs.SwaggerInfo.BasePath = "/api"
docs.SwaggerInfo.Host = fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port)
docs.SwaggerInfo.Schemes = []string{"http"}
@@ -63,62 +63,9 @@ func run(args []string) error {
docs.SwaggerInfo.Schemes = append(docs.SwaggerInfo.Schemes, "https")
}
rootFS := flag.NewFlagSet("goyco", flag.ContinueOnError)
rootFS.SetOutput(os.Stderr)
rootFS.Usage = printRootUsage
showHelp := rootFS.Bool("help", false, "show this help message")
jsonOutput := rootFS.Bool("json", cfg.CLI.JSONOutputDefault, "output results in JSON format")
if err := rootFS.Parse(args); err != nil {
if errors.Is(err, flag.ErrHelp) {
return nil
}
return fmt.Errorf("failed to parse arguments: %w", err)
}
if *showHelp {
printRootUsage()
return nil
}
commands.SetJSONOutput(*jsonOutput)
remaining := rootFS.Args()
if len(remaining) == 0 {
printRootUsage()
return nil
}
return dispatchCommand(cfg, remaining[0], remaining[1:])
}
func dispatchCommand(cfg *config.Config, name string, args []string) error {
switch name {
case "run":
return handleRunCommand(cfg, args)
case "start":
return commands.HandleStartCommand(cfg, args)
case "stop":
return commands.HandleStopCommand(cfg, args)
case "status":
return commands.HandleStatusCommand(cfg, name, args)
case "user":
return commands.HandleUserCommand(cfg, name, args)
case "post":
return commands.HandlePostCommand(cfg, name, args)
case "prune":
return commands.HandlePruneCommand(cfg, name, args)
case "migrate", "migrations":
return commands.HandleMigrateCommand(cfg, name, args)
case "seed":
return commands.HandleSeedCommand(cfg, name, args)
case "help", "-h", "--help":
printRootUsage()
return nil
default:
printRootUsage()
return fmt.Errorf("unknown command: %s", name)
}
root := buildRootCommand(cfg)
runArgs := append([]string{os.Args[0]}, args...)
return root.Run(context.Background(), runArgs)
}
func handleRunCommand(cfg *config.Config, args []string) error {

View File

@@ -3,14 +3,13 @@ package main
import (
"context"
"crypto/tls"
"errors"
"flag"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"goyco/cmd/goyco/commands"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/handlers"
@@ -76,6 +75,10 @@ func TestServerConfigurationFromConfig(t *testing.T) {
IdleTimeout: cfg.Server.IdleTimeout,
MaxHeaderBytes: cfg.Server.MaxHeaderBytes,
}
expectedAddr := cfg.Server.Host + ":" + cfg.Server.Port
if srv.Addr != expectedAddr {
t.Errorf("Expected server address to be %q, got %q", expectedAddr, srv.Addr)
}
if srv.ReadTimeout != 30*time.Second {
t.Errorf("Expected ReadTimeout to be 30s, got %v", srv.ReadTimeout)
@@ -172,6 +175,10 @@ func TestTLSWiringFromConfig(t *testing.T) {
t.Errorf("Expected server address to be %q, got %q", expectedAddr, srv.Addr)
}
if srv.ReadHeaderTimeout != 5*time.Second {
t.Errorf("Expected ReadHeaderTimeout to be 5s, got %v", srv.ReadHeaderTimeout)
}
if cfg.Server.EnableTLS {
srv.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
@@ -267,36 +274,37 @@ func TestConfigLoadingInCLI(t *testing.T) {
}
func TestFlagParsingInCLI(t *testing.T) {
originalArgs := os.Args
defer func() {
os.Args = originalArgs
}()
t.Run("help flag", func(t *testing.T) {
os.Args = []string{"goyco", "--help"}
fs := flag.NewFlagSet("goyco", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
showHelp := fs.Bool("help", false, "show help")
err := fs.Parse([]string{"--help"})
if err != nil && !errors.Is(err, flag.ErrHelp) {
cmd := buildRootCommand(testutils.NewTestConfig())
err := cmd.Run(context.Background(), []string{"goyco", "--help"})
if err != nil {
t.Errorf("Expected help flag parsing, got error: %v", err)
}
})
if !*showHelp {
t.Error("Expected help flag to be true")
t.Run("json flag", func(t *testing.T) {
cmd := buildRootCommand(testutils.NewTestConfig())
err := cmd.Run(context.Background(), []string{"goyco", "--json"})
if err != nil {
t.Errorf("Expected json flag parsing, got error: %v", err)
}
if !commands.IsJSONOutput() {
t.Error("Expected json output to be enabled")
}
commands.SetJSONOutput(false)
})
t.Run("command dispatch", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := dispatchCommand(cfg, "unknown", []string{})
cmd := buildRootCommand(cfg)
err := cmd.Run(context.Background(), []string{"goyco", "unknown"})
if err == nil {
t.Error("Expected error for unknown command")
}
err = dispatchCommand(cfg, "help", []string{})
cmd = buildRootCommand(cfg)
err = cmd.Run(context.Background(), []string{"goyco", "help"})
if err != nil {
t.Errorf("Help command should not error: %v", err)
}
@@ -364,6 +372,26 @@ func TestServerInitializationFlow(t *testing.T) {
if srv.Handler == nil {
t.Error("Expected server handler to be set")
}
expectedAddr := cfg.Server.Host + ":" + cfg.Server.Port
if srv.Addr != expectedAddr {
t.Errorf("Expected server address to be %q, got %q", expectedAddr, srv.Addr)
}
if srv.ReadTimeout != cfg.Server.ReadTimeout {
t.Errorf("Expected ReadTimeout to be %v, got %v", cfg.Server.ReadTimeout, srv.ReadTimeout)
}
if srv.WriteTimeout != cfg.Server.WriteTimeout {
t.Errorf("Expected WriteTimeout to be %v, got %v", cfg.Server.WriteTimeout, srv.WriteTimeout)
}
if srv.IdleTimeout != cfg.Server.IdleTimeout {
t.Errorf("Expected IdleTimeout to be %v, got %v", cfg.Server.IdleTimeout, srv.IdleTimeout)
}
if srv.MaxHeaderBytes != cfg.Server.MaxHeaderBytes {
t.Errorf("Expected MaxHeaderBytes to be %d, got %d", cfg.Server.MaxHeaderBytes, srv.MaxHeaderBytes)
}
testServer := httptest.NewServer(srv.Handler)
defer testServer.Close()

View File

@@ -324,7 +324,7 @@ const docTemplate = `{
"200": {
"description": "Authentication successful",
"schema": {
"$ref": "#/definitions/handlers.AuthTokensResponse"
"$ref": "#/definitions/handlers.AuthResponse"
}
},
"400": {
@@ -513,7 +513,7 @@ const docTemplate = `{
"200": {
"description": "Token refreshed successfully",
"schema": {
"$ref": "#/definitions/handlers.AuthTokensResponse"
"$ref": "#/definitions/handlers.AuthResponse"
}
},
"400": {
@@ -2064,63 +2064,6 @@ const docTemplate = `{
}
}
},
"handlers.AuthTokensDetail": {
"type": "object",
"properties": {
"access_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
},
"refresh_token": {
"type": "string",
"example": "f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7"
},
"user": {
"$ref": "#/definitions/handlers.AuthUserSummary"
}
}
},
"handlers.AuthTokensResponse": {
"type": "object",
"properties": {
"data": {
"$ref": "#/definitions/handlers.AuthTokensDetail"
},
"message": {
"type": "string",
"example": "Authentication successful"
},
"success": {
"type": "boolean",
"example": true
}
}
},
"handlers.AuthUserSummary": {
"type": "object",
"properties": {
"email": {
"type": "string",
"example": "jane@example.com"
},
"email_verified": {
"type": "boolean",
"example": true
},
"id": {
"type": "integer",
"example": 42
},
"locked": {
"type": "boolean",
"example": false
},
"username": {
"type": "string",
"example": "janedoe"
}
}
},
"handlers.CommonResponse": {
"type": "object",
"properties": {
@@ -2186,7 +2129,7 @@ const docTemplate = `{
// SwaggerInfo holds exported Swagger Info so clients can modify it
var SwaggerInfo = &swag.Spec{
Version: "0.1.0",
Version: "0.1.1",
Host: "localhost:8080",
BasePath: "/api",
Schemes: []string{"http"},

View File

@@ -14,7 +14,7 @@
"name": "GPLv3",
"url": "https://www.gnu.org/licenses/gpl-3.0.html"
},
"version": "0.1.0"
"version": "0.1.1"
},
"host": "localhost:8080",
"basePath": "/api",
@@ -321,7 +321,7 @@
"200": {
"description": "Authentication successful",
"schema": {
"$ref": "#/definitions/handlers.AuthTokensResponse"
"$ref": "#/definitions/handlers.AuthResponse"
}
},
"400": {
@@ -510,7 +510,7 @@
"200": {
"description": "Token refreshed successfully",
"schema": {
"$ref": "#/definitions/handlers.AuthTokensResponse"
"$ref": "#/definitions/handlers.AuthResponse"
}
},
"400": {
@@ -2061,63 +2061,6 @@
}
}
},
"handlers.AuthTokensDetail": {
"type": "object",
"properties": {
"access_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
},
"refresh_token": {
"type": "string",
"example": "f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7"
},
"user": {
"$ref": "#/definitions/handlers.AuthUserSummary"
}
}
},
"handlers.AuthTokensResponse": {
"type": "object",
"properties": {
"data": {
"$ref": "#/definitions/handlers.AuthTokensDetail"
},
"message": {
"type": "string",
"example": "Authentication successful"
},
"success": {
"type": "boolean",
"example": true
}
}
},
"handlers.AuthUserSummary": {
"type": "object",
"properties": {
"email": {
"type": "string",
"example": "jane@example.com"
},
"email_verified": {
"type": "boolean",
"example": true
},
"id": {
"type": "integer",
"example": 42
},
"locked": {
"type": "boolean",
"example": false
},
"username": {
"type": "string",
"example": "janedoe"
}
}
},
"handlers.CommonResponse": {
"type": "object",
"properties": {

View File

@@ -171,46 +171,6 @@ definitions:
success:
type: boolean
type: object
handlers.AuthTokensDetail:
properties:
access_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
refresh_token:
example: f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7
type: string
user:
$ref: '#/definitions/handlers.AuthUserSummary'
type: object
handlers.AuthTokensResponse:
properties:
data:
$ref: '#/definitions/handlers.AuthTokensDetail'
message:
example: Authentication successful
type: string
success:
example: true
type: boolean
type: object
handlers.AuthUserSummary:
properties:
email:
example: jane@example.com
type: string
email_verified:
example: true
type: boolean
id:
example: 42
type: integer
locked:
example: false
type: boolean
username:
example: janedoe
type: string
type: object
handlers.CommonResponse:
properties:
data: {}
@@ -261,7 +221,7 @@ info:
name: GPLv3
url: https://www.gnu.org/licenses/gpl-3.0.html
title: Goyco API
version: 0.1.0
version: 0.1.1
paths:
/api:
get:
@@ -459,7 +419,7 @@ paths:
"200":
description: Authentication successful
schema:
$ref: '#/definitions/handlers.AuthTokensResponse'
$ref: '#/definitions/handlers.AuthResponse'
"400":
description: Invalid request data or validation failed
schema:
@@ -580,7 +540,7 @@ paths:
"200":
description: Token refreshed successfully
schema:
$ref: '#/definitions/handlers.AuthTokensResponse'
$ref: '#/definitions/handlers.AuthResponse'
"400":
description: Invalid request body or missing refresh token
schema:

1
go.mod
View File

@@ -12,6 +12,7 @@ require (
github.com/stretchr/testify v1.11.1
github.com/swaggo/http-swagger v1.3.4
github.com/swaggo/swag v1.16.6
github.com/urfave/cli/v3 v3.6.1
golang.org/x/crypto v0.43.0
golang.org/x/net v0.46.0
gorm.io/driver/postgres v1.6.0

2
go.sum
View File

@@ -116,3 +116,5 @@ gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
github.com/urfave/cli/v3 v3.6.1 h1:j8Qq8NyUawj/7rTYdBGrxcH7A/j7/G8Q5LhWEW4G3Mo=
github.com/urfave/cli/v3 v3.6.1/go.mod h1:ysVLtOEmg2tOy6PknnYVhDoouyC/6N42TMeoMzskhso=

View File

@@ -0,0 +1,23 @@
package dto
import (
"goyco/internal/services"
)
type AuthResponseDTO struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
User UserDTO `json:"user"`
}
func ToAuthResponseDTO(result *services.AuthResult) AuthResponseDTO {
if result == nil {
return AuthResponseDTO{}
}
return AuthResponseDTO{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
User: ToUserDTO(result.User),
}
}

View File

@@ -96,3 +96,7 @@ func ToSearchPostListDTO(posts []database.Post, query string, limit, offset int)
Offset: offset,
}
}
type TitleResponseDTO struct {
Title string `json:"title"`
}

View File

@@ -117,3 +117,9 @@ func ToRegistrationResponseDTO(user *database.User, verificationSent bool) Regis
type AccountDeletionResponseDTO struct {
PostsDeleted bool `json:"posts_deleted"`
}
type MessageResponseDTO struct {
Message string `json:"message"`
}
type EmptyResponseDTO struct{}

View File

@@ -1,89 +1,35 @@
package fuzz
import (
"sync"
"goyco/internal/database"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
var (
fuzzDBOnce sync.Once
fuzzDB *gorm.DB
fuzzDBErr error
)
func GetFuzzDB() (*gorm.DB, error) {
fuzzDBOnce.Do(func() {
dbName := "file:memdb_fuzz?mode=memory&cache=shared&_journal_mode=WAL&_synchronous=NORMAL"
fuzzDB, fuzzDBErr = gorm.Open(sqlite.Open(dbName), &gorm.Config{
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if fuzzDBErr == nil {
fuzzDBErr = fuzzDB.Exec(`
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
email TEXT UNIQUE NOT NULL,
password TEXT NOT NULL,
email_verified INTEGER DEFAULT 0 NOT NULL,
email_verified_at DATETIME,
email_verification_token TEXT,
email_verification_sent_at DATETIME,
password_reset_token TEXT,
password_reset_sent_at DATETIME,
password_reset_expires_at DATETIME,
locked INTEGER DEFAULT 0,
session_version INTEGER DEFAULT 1 NOT NULL,
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME
);
CREATE TABLE IF NOT EXISTS posts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
url TEXT UNIQUE,
content TEXT,
author_id INTEGER,
author_name TEXT,
up_votes INTEGER DEFAULT 0,
down_votes INTEGER DEFAULT 0,
score INTEGER DEFAULT 0,
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME,
FOREIGN KEY(author_id) REFERENCES users(id)
);
CREATE TABLE IF NOT EXISTS votes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER,
post_id INTEGER NOT NULL,
type TEXT NOT NULL,
vote_hash TEXT,
created_at DATETIME,
updated_at DATETIME,
FOREIGN KEY(user_id) REFERENCES users(id),
FOREIGN KEY(post_id) REFERENCES posts(id)
);
CREATE TABLE IF NOT EXISTS account_deletion_requests (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
token_hash TEXT UNIQUE NOT NULL,
expires_at DATETIME NOT NULL,
created_at DATETIME,
FOREIGN KEY(user_id) REFERENCES users(id)
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
token_hash TEXT UNIQUE NOT NULL,
expires_at DATETIME NOT NULL,
created_at DATETIME,
FOREIGN KEY(user_id) REFERENCES users(id)
);
`).Error
if err != nil {
return nil, err
}
})
return fuzzDB, fuzzDBErr
if execErr := db.Exec("PRAGMA foreign_keys = ON").Error; execErr != nil {
return nil, execErr
}
err = db.AutoMigrate(
&database.User{},
&database.Post{},
&database.Vote{},
&database.AccountDeletionRequest{},
&database.RefreshToken{},
)
if err != nil {
return nil, err
}
return db, nil
}

View File

@@ -1704,8 +1704,11 @@ func TestGetFuzzDB(t *testing.T) {
if err2 != nil {
t.Fatalf("Second GetFuzzDB call failed: %v", err2)
}
if db2 != db {
t.Fatal("GetFuzzDB should return the same database instance")
if db2 == nil {
t.Fatal("Second GetFuzzDB returned nil database")
}
if db2 == db {
t.Fatal("GetFuzzDB should return a new database instance for each call")
}
}

View File

@@ -75,7 +75,7 @@ func (h *APIHandler) GetAPIInfo(w http.ResponseWriter, r *http.Request) {
apiInfo := map[string]any{
"name": fmt.Sprintf("%s API", h.config.App.Title),
"version": version.Version,
"version": version.GetVersion(),
"description": "Y Combinator-style news board API",
"endpoints": map[string]any{
"authentication": map[string]any{
@@ -145,7 +145,7 @@ func (h *APIHandler) GetHealth(w http.ResponseWriter, r *http.Request) {
if h.healthChecker != nil {
health := h.healthChecker.CheckHealth()
health["version"] = version.Version
health["version"] = version.GetVersion()
SendSuccessResponse(w, "Health check successful", health)
return
}
@@ -155,7 +155,7 @@ func (h *APIHandler) GetHealth(w http.ResponseWriter, r *http.Request) {
health := map[string]any{
"status": "healthy",
"timestamp": currentTimestamp,
"version": version.Version,
"version": version.GetVersion(),
"services": map[string]any{
"database": "connected",
"api": "running",
@@ -230,7 +230,7 @@ func (h *APIHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
},
"system": map[string]any{
"timestamp": time.Now().UTC().Format(time.RFC3339),
"version": version.Version,
"version": version.GetVersion(),
},
}

View File

@@ -44,26 +44,6 @@ type AuthHandler struct {
type AuthResponse = CommonResponse
type AuthTokensResponse struct {
Success bool `json:"success" example:"true"`
Message string `json:"message" example:"Authentication successful"`
Data AuthTokensDetail `json:"data"`
}
type AuthTokensDetail struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7"`
User AuthUserSummary `json:"user"`
}
type AuthUserSummary struct {
ID uint `json:"id" example:"42"`
Username string `json:"username" example:"janedoe"`
Email string `json:"email" example:"jane@example.com"`
EmailVerified bool `json:"email_verified" example:"true"`
Locked bool `json:"locked" example:"false"`
}
func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.UserRepository) *AuthHandler {
return &AuthHandler{
authService: authService,
@@ -77,28 +57,28 @@ func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.User
// @Accept json
// @Produce json
// @Param request body dto.LoginRequest true "Login credentials"
// @Success 200 {object} AuthTokensResponse "Authentication successful"
// @Success 200 {object} AuthResponse "Authentication successful"
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
// @Failure 401 {object} AuthResponse "Invalid credentials"
// @Failure 403 {object} AuthResponse "Account is locked"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/login [post]
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
req, ok := GetValidatedDTO[dto.LoginRequest](r)
request, ok := GetValidatedDTO[dto.LoginRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
username := security.SanitizeUsername(req.Username)
password := strings.TrimSpace(req.Password)
username := security.SanitizeUsername(request.Username)
password := strings.TrimSpace(request.Password)
result, err := h.authService.Login(username, password)
if !HandleServiceError(w, err, "Authentication failed", http.StatusInternalServerError) {
return
}
SendSuccessResponse(w, "Authentication successful", result)
responseDTO := dto.ToAuthResponseDTO(result)
SendSuccessResponse(w, "Authentication successful", responseDTO)
}
// @Summary Register a new user
@@ -113,31 +93,16 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/register [post]
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
req, ok := GetValidatedDTO[dto.RegisterRequest](r)
request, ok := GetValidatedDTO[dto.RegisterRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
username := strings.TrimSpace(req.Username)
email := strings.TrimSpace(req.Email)
password := strings.TrimSpace(req.Password)
username := strings.TrimSpace(request.Username)
email := strings.TrimSpace(request.Email)
password := strings.TrimSpace(request.Password)
username = security.SanitizeUsername(username)
if err := validation.ValidateUsername(username); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidateEmail(email); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(password); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
result, err := h.authService.Register(username, email, password)
if err != nil {
@@ -196,13 +161,12 @@ func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse
// @Router /api/auth/resend-verification [post]
func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Request) {
req, ok := GetValidatedDTO[dto.ResendVerificationRequest](r)
request, ok := GetValidatedDTO[dto.ResendVerificationRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
email := strings.TrimSpace(req.Email)
email := strings.TrimSpace(request.Email)
if email == "" {
SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
@@ -228,9 +192,10 @@ func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Req
return
}
SendSuccessResponse(w, "Verification email sent successfully", map[string]any{
"message": "Check your inbox for the verification link",
})
responseDTO := dto.MessageResponseDTO{
Message: "Check your inbox for the verification link",
}
SendSuccessResponse(w, "Verification email sent successfully", responseDTO)
}
// @Summary Get current user profile
@@ -269,13 +234,12 @@ func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
// @Failure 400 {object} AuthResponse "Invalid request data"
// @Router /api/auth/forgot-password [post]
func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
req, ok := GetValidatedDTO[dto.ForgotPasswordRequest](r)
request, ok := GetValidatedDTO[dto.ForgotPasswordRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail)
usernameOrEmail := strings.TrimSpace(request.UsernameOrEmail)
if usernameOrEmail == "" {
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
@@ -285,7 +249,7 @@ func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Reques
if err := h.authService.RequestPasswordReset(usernameOrEmail); err != nil {
}
SendSuccessResponse(w, "If an account with that username or email exists, we've sent a password reset link.", nil)
SendSuccessResponse(w, "If an account with that username or email exists, we've sent a password reset link.", dto.EmptyResponseDTO{})
}
// @Summary Reset password
@@ -299,25 +263,19 @@ func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Reques
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/reset-password [post]
func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
req, ok := GetValidatedDTO[dto.ResetPasswordRequest](r)
request, ok := GetValidatedDTO[dto.ResetPasswordRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
token := strings.TrimSpace(req.Token)
newPassword := strings.TrimSpace(req.NewPassword)
token := strings.TrimSpace(request.Token)
newPassword := strings.TrimSpace(request.NewPassword)
if token == "" {
SendErrorResponse(w, "Token is required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := h.authService.ResetPassword(token, newPassword); err != nil {
switch {
case strings.Contains(err.Error(), "expired"):
@@ -330,7 +288,7 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
return
}
SendSuccessResponse(w, "Password reset successfully. You can now sign in with your new password.", nil)
SendSuccessResponse(w, "Password reset successfully. You can now sign in with your new password.", dto.EmptyResponseDTO{})
}
// @Summary Update email address
@@ -353,17 +311,12 @@ func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
return
}
req, ok := GetValidatedDTO[dto.UpdateEmailRequest](r)
request, ok := GetValidatedDTO[dto.UpdateEmailRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
email := strings.TrimSpace(req.Email)
if err := validation.ValidateEmail(email); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
email := strings.TrimSpace(request.Email)
user, err := h.authService.UpdateEmail(userID, email)
if err != nil {
@@ -403,17 +356,12 @@ func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
return
}
req, ok := GetValidatedDTO[dto.UpdateUsernameRequest](r)
request, ok := GetValidatedDTO[dto.UpdateUsernameRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
username := strings.TrimSpace(req.Username)
if err := validation.ValidateUsername(username); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
username := strings.TrimSpace(request.Username)
user, err := h.authService.UpdateUsername(userID, username)
if err != nil {
@@ -448,24 +396,13 @@ func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
return
}
req, ok := GetValidatedDTO[dto.UpdatePasswordRequest](r)
request, ok := GetValidatedDTO[dto.UpdatePasswordRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
currentPassword := strings.TrimSpace(req.CurrentPassword)
newPassword := strings.TrimSpace(req.NewPassword)
if currentPassword == "" {
SendErrorResponse(w, "Current password is required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
currentPassword := strings.TrimSpace(request.CurrentPassword)
newPassword := strings.TrimSpace(request.NewPassword)
user, err := h.authService.UpdatePassword(userID, currentPassword, newPassword)
if err != nil {
@@ -508,7 +445,7 @@ func (h *AuthHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
return
}
SendSuccessResponse(w, "Check your inbox for a confirmation link to finish deleting your account.", nil)
SendSuccessResponse(w, "Check your inbox for a confirmation link to finish deleting your account.", dto.EmptyResponseDTO{})
}
// @Summary Confirm account deletion
@@ -523,20 +460,19 @@ func (h *AuthHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/account/confirm [post]
func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
req, ok := GetValidatedDTO[dto.ConfirmAccountDeletionRequest](r)
request, ok := GetValidatedDTO[dto.ConfirmAccountDeletionRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
token := strings.TrimSpace(req.Token)
token := strings.TrimSpace(request.Token)
if token == "" {
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
return
}
if err := h.authService.ConfirmAccountDeletionWithPosts(token, req.DeletePosts); err != nil {
if err := h.authService.ConfirmAccountDeletionWithPosts(token, request.DeletePosts); err != nil {
switch {
case errors.Is(err, services.ErrInvalidDeletionToken):
SendErrorResponse(w, "This deletion link is invalid or has expired.", http.StatusBadRequest)
@@ -544,7 +480,7 @@ func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Requ
SendErrorResponse(w, "Account deletion isn't available right now because email delivery is disabled.", http.StatusServiceUnavailable)
case errors.Is(err, services.ErrDeletionEmailFailed):
responseDTO := dto.AccountDeletionResponseDTO{
PostsDeleted: req.DeletePosts,
PostsDeleted: request.DeletePosts,
}
SendSuccessResponse(w, "Your account has been deleted, but we couldn't send the confirmation email.", responseDTO)
default:
@@ -554,7 +490,7 @@ func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Requ
}
responseDTO := dto.AccountDeletionResponseDTO{
PostsDeleted: req.DeletePosts,
PostsDeleted: request.DeletePosts,
}
SendSuccessResponse(w, "Your account has been deleted.", responseDTO)
}
@@ -569,7 +505,7 @@ func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Requ
// @Failure 401 {object} AuthResponse "Authentication required"
// @Router /api/auth/logout [post]
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
SendSuccessResponse(w, "Logged out successfully", nil)
SendSuccessResponse(w, "Logged out successfully", dto.EmptyResponseDTO{})
}
// @Summary Refresh access token
@@ -578,30 +514,30 @@ func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Param request body dto.RefreshTokenRequest true "Refresh token data"
// @Success 200 {object} AuthTokensResponse "Token refreshed successfully"
// @Success 200 {object} AuthResponse "Token refreshed successfully"
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
// @Failure 401 {object} AuthResponse "Invalid or expired refresh token"
// @Failure 403 {object} AuthResponse "Account is locked"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/refresh [post]
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
req, ok := GetValidatedDTO[dto.RefreshTokenRequest](r)
request, ok := GetValidatedDTO[dto.RefreshTokenRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
if req.RefreshToken == "" {
if request.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
result, err := h.authService.RefreshAccessToken(req.RefreshToken)
result, err := h.authService.RefreshAccessToken(request.RefreshToken)
if !HandleServiceError(w, err, "Token refresh failed", http.StatusInternalServerError) {
return
}
SendSuccessResponse(w, "Token refreshed successfully", result)
responseDTO := dto.ToAuthResponseDTO(result)
SendSuccessResponse(w, "Token refreshed successfully", responseDTO)
}
// @Summary Revoke refresh token
@@ -617,24 +553,23 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/revoke [post]
func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
req, ok := GetValidatedDTO[dto.RevokeTokenRequest](r)
request, ok := GetValidatedDTO[dto.RevokeTokenRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
if req.RefreshToken == "" {
if request.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
err := h.authService.RevokeRefreshToken(req.RefreshToken)
err := h.authService.RevokeRefreshToken(request.RefreshToken)
if err != nil {
SendErrorResponse(w, "Failed to revoke token", http.StatusInternalServerError)
return
}
SendSuccessResponse(w, "Token revoked successfully", nil)
SendSuccessResponse(w, "Token revoked successfully", dto.EmptyResponseDTO{})
}
// @Summary Revoke all user tokens
@@ -659,7 +594,7 @@ func (h *AuthHandler) RevokeAllTokens(w http.ResponseWriter, r *http.Request) {
return
}
SendSuccessResponse(w, "All tokens revoked successfully", nil)
SendSuccessResponse(w, "All tokens revoked successfully", dto.EmptyResponseDTO{})
}
func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {

View File

@@ -693,7 +693,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request",
expectedError: "Invalid JSON",
},
{
name: "empty email",
@@ -701,7 +701,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest,
expectedError: "Email is required",
expectedError: "email is required",
},
{
name: "email already taken",
@@ -788,7 +788,7 @@ func TestAuthHandlerUpdateUsername(t *testing.T) {
userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest,
expectedError: "Username is required",
expectedError: "username is required",
},
{
name: "username already taken",
@@ -876,7 +876,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest,
expectedError: "Current password is required",
expectedError: "current_password is required",
},
{
name: "empty new password",
@@ -884,7 +884,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest,
expectedError: "Password is required",
expectedError: "new_password is required",
},
{
name: "short new password",
@@ -892,7 +892,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest,
expectedError: "Password must be at least 8 characters long",
expectedError: "new_password must be at least 8 characters",
},
{
name: "incorrect current password",
@@ -1042,13 +1042,13 @@ func TestAuthHandlerResendVerificationEmail(t *testing.T) {
name: "invalid json",
body: "not-json",
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request",
expectedError: "Invalid JSON",
},
{
name: "missing email",
body: `{}`,
expectedStatus: http.StatusBadRequest,
expectedError: "Email address is required",
expectedError: "email is required",
},
{
name: "account not found",
@@ -1167,13 +1167,13 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
name: "invalid json",
body: "not-json",
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request",
expectedError: "Invalid JSON",
},
{
name: "missing token",
body: `{}`,
expectedStatus: http.StatusBadRequest,
expectedError: "Deletion token is required",
expectedError: "token is required",
},
{
name: "invalid token from service",

View File

@@ -15,6 +15,7 @@ import (
"goyco/internal/dto"
"goyco/internal/middleware"
"goyco/internal/services"
"goyco/internal/validation"
"github.com/go-chi/chi/v5"
"gorm.io/gorm"
@@ -272,13 +273,51 @@ func HandleServiceError(w http.ResponseWriter, err error, defaultMsg string, def
return false
}
func GetValidatedDTO[T any](r *http.Request) (*T, bool) {
func GetValidatedDTO[T any](w http.ResponseWriter, r *http.Request) (*T, bool) {
dtoVal := middleware.GetValidatedDTOFromContext(r.Context())
if dtoVal == nil {
dtoTypeInContext := middleware.GetDTOTypeFromContext(r.Context())
var dto *T
needsValidation := false
if dtoVal != nil {
var ok bool
dto, ok = dtoVal.(*T)
if !ok {
return nil, false
}
dto, ok := dtoVal.(*T)
return dto, ok
if dtoTypeInContext == nil {
needsValidation = true
}
} else {
var decoded T
if err := json.NewDecoder(r.Body).Decode(&decoded); err != nil {
SendErrorResponse(w, "Invalid JSON", http.StatusBadRequest)
return nil, false
}
dto = &decoded
needsValidation = true
}
if needsValidation {
if err := validation.ValidateStruct(dto); err != nil {
var errorMessages []string
if structErr, ok := err.(*validation.StructValidationError); ok {
errorMessages = make([]string, len(structErr.Errors))
for i, fieldError := range structErr.Errors {
errorMessages[i] = fieldError.Message
}
} else {
errorMessages = []string{err.Error()}
}
errorMsg := strings.Join(errorMessages, "; ")
SendErrorResponse(w, errorMsg, http.StatusBadRequest)
return nil, false
}
}
return dto, true
}
func WithValidation[T any](validationMiddleware func(http.Handler) http.Handler, handler http.HandlerFunc) http.HandlerFunc {

View File

@@ -35,25 +35,25 @@ func FuzzJSONParsing(f *testing.F) {
func FuzzURLParsing(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
sanitized := ""
var sanitized strings.Builder
sanitized.Grow(len(input))
sanitizedLen := 0
for _, char := range input {
if (char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') ||
(char >= '0' && char <= '9') || char == '-' || char == '_' {
sanitized += string(char)
sanitized.WriteRune(char)
sanitizedLen++
if sanitizedLen >= 20 {
break
}
}
}
if len(sanitized) > 20 {
sanitized = sanitized[:20]
}
if len(sanitized) == 0 {
if sanitizedLen == 0 {
return
}
url := "/api/posts/" + sanitized
url := "/api/posts/" + sanitized.String()
req := httptest.NewRequest("GET", url, nil)
pathParts := strings.Split(req.URL.Path, "/")
@@ -67,46 +67,52 @@ func FuzzURLParsing(f *testing.F) {
func FuzzQueryParameters(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
if !utf8.ValidString(input) {
return
}
sanitized := ""
var sanitized strings.Builder
sanitized.Grow(len(input))
sanitizedLen := 0
for _, char := range input {
if char >= 32 && char <= 126 {
switch char {
case ' ', '\n', '\r', '\t':
continue
case '&':
sanitized += "%26"
sanitized.WriteString("%26")
sanitizedLen += 3
case '=':
sanitized += "%3D"
sanitized.WriteString("%3D")
sanitizedLen += 3
case '?':
sanitized += "%3F"
sanitized.WriteString("%3F")
sanitizedLen += 3
case '#':
sanitized += "%23"
sanitized.WriteString("%23")
sanitizedLen += 3
case '/':
sanitized += "%2F"
sanitized.WriteString("%2F")
sanitizedLen += 3
case '\\':
sanitized += "%5C"
sanitized.WriteString("%5C")
sanitizedLen += 3
default:
sanitized += string(char)
sanitized.WriteRune(char)
sanitizedLen++
}
if sanitizedLen >= 100 {
break
}
}
}
if len(sanitized) > 100 {
sanitized = sanitized[:100]
}
if len(sanitized) == 0 {
if sanitizedLen == 0 {
return
}
query := "?q=" + sanitized + "&limit=10&offset=0"
query := "?q=" + sanitized.String() + "&limit=10&offset=0"
req := httptest.NewRequest("GET", "/api/posts/search"+query, nil)
q := req.URL.Query().Get("q")

View File

@@ -12,7 +12,6 @@ import (
"goyco/internal/repositories"
"goyco/internal/security"
"goyco/internal/services"
"goyco/internal/validation"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgconn"
@@ -63,13 +62,7 @@ func (h *PostHandler) GetPosts(w http.ResponseWriter, r *http.Request) {
return
}
postDTOs := dto.ToPostDTOs(posts)
responseDTO := dto.PostListDTO{
Posts: postDTOs,
Count: len(postDTOs),
Limit: limit,
Offset: offset,
}
responseDTO := dto.ToPostListDTO(posts, limit, offset)
SendSuccessResponse(w, "Posts retrieved successfully", responseDTO)
}
@@ -116,9 +109,8 @@ func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} PostResponse "Internal server error"
// @Router /api/posts [post]
func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
req, ok := GetValidatedDTO[dto.CreatePostRequest](r)
request, ok := GetValidatedDTO[dto.CreatePostRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -127,9 +119,9 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return
}
title := security.SanitizeInput(req.Title)
url := security.SanitizeURL(req.URL)
content := security.SanitizePostContent(req.Content)
title := security.SanitizeInput(request.Title)
url := security.SanitizeURL(request.URL)
content := security.SanitizePostContent(request.Content)
if url == "" {
SendErrorResponse(w, "Invalid URL", http.StatusBadRequest)
@@ -230,14 +222,7 @@ func (h *PostHandler) SearchPosts(w http.ResponseWriter, r *http.Request) {
return
}
postDTOs := dto.ToPostDTOs(posts)
responseDTO := dto.SearchPostListDTO{
Posts: postDTOs,
Count: len(postDTOs),
Query: query,
Limit: limit,
Offset: offset,
}
responseDTO := dto.ToSearchPostListDTO(posts, query, limit, offset)
SendSuccessResponse(w, "Search results retrieved successfully", responseDTO)
}
@@ -277,24 +262,13 @@ func (h *PostHandler) UpdatePost(w http.ResponseWriter, r *http.Request) {
return
}
req, ok := GetValidatedDTO[dto.UpdatePostRequest](r)
request, ok := GetValidatedDTO[dto.UpdatePostRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
title := security.SanitizeInput(req.Title)
content := security.SanitizePostContent(req.Content)
if err := validation.ValidateTitle(title); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidateContent(content); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
title := security.SanitizeInput(request.Title)
content := security.SanitizePostContent(request.Content)
post.Title = title
post.Content = content
@@ -353,7 +327,7 @@ func (h *PostHandler) DeletePost(w http.ResponseWriter, r *http.Request) {
return
}
SendSuccessResponse(w, "Post deleted successfully", nil)
SendSuccessResponse(w, "Post deleted successfully", dto.EmptyResponseDTO{})
}
// @Summary Fetch title from URL
@@ -395,9 +369,10 @@ func (h *PostHandler) FetchTitleFromURL(w http.ResponseWriter, r *http.Request)
return
}
SendSuccessResponse(w, "Title fetched successfully", map[string]string{
"title": title,
})
responseDTO := dto.TitleResponseDTO{
Title: title,
}
SendSuccessResponse(w, "Title fetched successfully", responseDTO)
}
func translatePostCreateError(err error) (string, int) {

View File

@@ -10,14 +10,15 @@ import (
"strings"
"testing"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgconn"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/repositories"
"goyco/internal/services"
"goyco/internal/testutils"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgconn"
"gorm.io/gorm"
)
func decodeHandlerResponse(t *testing.T, rr *httptest.ResponseRecorder) map[string]any {
@@ -310,7 +311,7 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = createCreatePostRequest(`{"title":"ok","url":"https://example.com"}`)
request = createCreatePostRequest(`{"title":"okay","url":"https://example.com"}`)
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
}
@@ -466,7 +467,7 @@ func TestPostHandlerUpdatePost(t *testing.T) {
}
},
expectedStatus: http.StatusBadRequest,
expectedError: "Title is required",
expectedError: "title is required",
},
{
name: "short title",
@@ -480,7 +481,7 @@ func TestPostHandlerUpdatePost(t *testing.T) {
}
},
expectedStatus: http.StatusBadRequest,
expectedError: "Title must be at least 3 characters",
expectedError: "title must be at least 3 characters",
},
}

View File

@@ -46,13 +46,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
return
}
userDTOs := dto.ToSanitizedUserDTOs(users)
responseDTO := dto.SanitizedUserListDTO{
Users: userDTOs,
Count: len(userDTOs),
Limit: limit,
Offset: offset,
}
responseDTO := dto.ToSanitizedUserListDTO(users, limit, offset)
SendSuccessResponse(w, "Users retrieved successfully", responseDTO)
}
@@ -99,28 +93,12 @@ func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} UserResponse "Internal server error"
// @Router /api/users [post]
func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
req, ok := GetValidatedDTO[dto.RegisterRequest](r)
request, ok := GetValidatedDTO[dto.RegisterRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
if err := validation.ValidateUsername(req.Username); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidateEmail(req.Email); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(req.Password); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
result, err := h.authService.Register(req.Username, req.Email, req.Password)
result, err := h.authService.Register(request.Username, request.Email, request.Password)
if err != nil {
var validationErr *validation.ValidationError
if errors.As(err, &validationErr) {
@@ -164,13 +142,7 @@ func (h *UserHandler) GetUserPosts(w http.ResponseWriter, r *http.Request) {
return
}
postDTOs := dto.ToPostDTOs(posts)
responseDTO := dto.PostListDTO{
Posts: postDTOs,
Count: len(postDTOs),
Limit: limit,
Offset: offset,
}
responseDTO := dto.ToPostListDTO(posts, limit, offset)
SendSuccessResponse(w, "User posts retrieved successfully", responseDTO)
}

View File

@@ -78,14 +78,13 @@ func (h *VoteHandler) CastVote(w http.ResponseWriter, r *http.Request) {
return
}
req, ok := GetValidatedDTO[dto.CastVoteRequest](r)
request, ok := GetValidatedDTO[dto.CastVoteRequest](w, r)
if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return
}
var voteType database.VoteType
switch req.Type {
switch request.Type {
case "up":
voteType = database.VoteUp
case "down":

View File

@@ -1,12 +1,9 @@
package repositories
import (
"fmt"
"testing"
"goyco/internal/database"
"gorm.io/gorm"
)
func TestDatabase_AssertUserExists(t *testing.T) {
@@ -384,245 +381,3 @@ func TestDatabase_CreateTestAccountDeletionRequest(t *testing.T) {
}
})
}
func TestApplyPagination(t *testing.T) {
suite := NewTestSuite(t)
tests := []struct {
name string
limit int
offset int
setupQuery func(*gorm.DB) *gorm.DB
verifyPagination func(*testing.T, *gorm.DB, int, int)
}{
{
name: "limit > 0 and offset > 0",
limit: 10,
offset: 5,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := 0; i < 20; i++ {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != limit {
t.Errorf("Expected %d users, got %d", limit, len(users))
}
},
},
{
name: "limit > 0 and offset = 0",
limit: 5,
offset: 0,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := 0; i < 10; i++ {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != limit {
t.Errorf("Expected %d users, got %d", limit, len(users))
}
},
},
{
name: "limit = 0 (should not apply limit)",
limit: 0,
offset: 5,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := 0; i < 10; i++ {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
expected := 5
if len(users) != expected {
t.Errorf("Expected %d users with offset %d, got %d", expected, offset, len(users))
}
},
},
{
name: "offset = 0 (should not apply offset)",
limit: 10,
offset: 0,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := 0; i < 15; i++ {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != limit {
t.Errorf("Expected %d users, got %d", limit, len(users))
}
},
},
{
name: "limit = 0 and offset = 0 (should not apply pagination)",
limit: 0,
offset: 0,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := 0; i < 10; i++ {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != 10 {
t.Errorf("Expected all 10 users, got %d", len(users))
}
},
},
{
name: "negative limit (should not apply limit)",
limit: -5,
offset: 10,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := 0; i < 20; i++ {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
expected := 10
if len(users) != expected {
t.Errorf("Expected %d users with offset %d, got %d", expected, offset, len(users))
}
},
},
{
name: "negative offset (should not apply offset)",
limit: 10,
offset: -5,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := 0; i < 15; i++ {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != limit {
t.Errorf("Expected %d users, got %d", limit, len(users))
}
},
},
{
name: "large limit and offset values",
limit: 1000,
offset: 500,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := 0; i < 2000; i++ {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != limit {
t.Errorf("Expected %d users, got %d", limit, len(users))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
suite.Reset()
baseQuery := tt.setupQuery(suite.DB)
paginatedQuery := ApplyPagination(baseQuery, tt.limit, tt.offset)
tt.verifyPagination(t, paginatedQuery, tt.limit, tt.offset)
})
}
}

View File

@@ -0,0 +1,252 @@
package repositories
import (
"fmt"
"testing"
"goyco/internal/database"
"gorm.io/gorm"
)
func TestApplyPagination(t *testing.T) {
suite := NewTestSuite(t)
tests := []struct {
name string
limit int
offset int
setupQuery func(*gorm.DB) *gorm.DB
verifyPagination func(*testing.T, *gorm.DB, int, int)
}{
{
name: "limit > 0 and offset > 0",
limit: 10,
offset: 5,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := range 20 {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != limit {
t.Errorf("Expected %d users, got %d", limit, len(users))
}
},
},
{
name: "limit > 0 and offset = 0",
limit: 5,
offset: 0,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := range 10 {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != limit {
t.Errorf("Expected %d users, got %d", limit, len(users))
}
},
},
{
name: "limit = 0 (should not apply limit)",
limit: 0,
offset: 5,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := range 10 {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
expected := 5
if len(users) != expected {
t.Errorf("Expected %d users with offset %d, got %d", expected, offset, len(users))
}
},
},
{
name: "offset = 0 (should not apply offset)",
limit: 10,
offset: 0,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := range 15 {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != limit {
t.Errorf("Expected %d users, got %d", limit, len(users))
}
},
},
{
name: "limit = 0 and offset = 0 (should not apply pagination)",
limit: 0,
offset: 0,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := range 10 {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != 10 {
t.Errorf("Expected all 10 users, got %d", len(users))
}
},
},
{
name: "negative limit (should not apply limit)",
limit: -5,
offset: 10,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := range 20 {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
expected := 10
if len(users) != expected {
t.Errorf("Expected %d users with offset %d, got %d", expected, offset, len(users))
}
},
},
{
name: "negative offset (should not apply offset)",
limit: 10,
offset: -5,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := range 15 {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != limit {
t.Errorf("Expected %d users, got %d", limit, len(users))
}
},
},
{
name: "large limit and offset values",
limit: 1000,
offset: 500,
setupQuery: func(db *gorm.DB) *gorm.DB {
return db.Model(&database.User{})
},
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
for i := range 2000 {
suite.CreateTestUser(
fmt.Sprintf("testuser_%d", i),
fmt.Sprintf("user%d@example.com", i),
"password123",
)
}
var users []database.User
result := query.Find(&users)
if result.Error != nil {
t.Fatalf("Query failed: %v", result.Error)
}
if len(users) != limit {
t.Errorf("Expected %d users, got %d", limit, len(users))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
suite.Reset()
baseQuery := tt.setupQuery(suite.DB)
paginatedQuery := ApplyPagination(baseQuery, tt.limit, tt.offset)
tt.verifyPagination(t, paginatedQuery, tt.limit, tt.offset)
})
}
}

View File

@@ -3,51 +3,61 @@ package validation
import (
"strings"
"testing"
"goyco/internal/fuzz"
"unicode/utf8"
)
func FuzzValidateEmail(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunValidationFuzzTest(f, ValidateEmail)
runValidationFuzzTest(f, ValidateEmail)
}
func FuzzValidateUsername(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunValidationFuzzTest(f, ValidateUsername)
runValidationFuzzTest(f, ValidateUsername)
}
func FuzzValidatePassword(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunValidationFuzzTest(f, ValidatePassword)
runValidationFuzzTest(f, ValidatePassword)
}
func FuzzValidateURL(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunValidationFuzzTest(f, ValidateURL)
runValidationFuzzTest(f, ValidateURL)
}
func FuzzValidateTitle(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunValidationFuzzTest(f, ValidateTitle)
runValidationFuzzTest(f, ValidateTitle)
}
func FuzzValidateContent(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunValidationFuzzTest(f, ValidateContent)
runValidationFuzzTest(f, ValidateContent)
}
func FuzzValidateSearchQuery(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunValidationFuzzTest(f, ValidateSearchQuery)
runValidationFuzzTest(f, ValidateSearchQuery)
}
func FuzzSanitizeString(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunSanitizationFuzzTestWithValidation(f,
SanitizeString,
func(result string) bool {
return !containsNullBytes(result)
f.Add("test input")
f.Fuzz(func(t *testing.T, input string) {
if !utf8.ValidString(input) {
return
}
result := SanitizeString(input)
if !utf8.ValidString(result) {
t.Fatal("Sanitized result contains invalid UTF-8")
}
if containsNullBytes(result) {
t.Fatal("Sanitized result contains null bytes")
}
})
}
func runValidationFuzzTest(f *testing.F, validateFunc func(string) error) {
f.Add("test input")
f.Fuzz(func(t *testing.T, input string) {
if !utf8.ValidString(input) {
return
}
err := validateFunc(input)
_ = err
})
}

View File

@@ -175,6 +175,33 @@ type FieldValidationError struct {
Message string
}
func getFieldDisplayName(field reflect.StructField) string {
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
if idx := strings.Index(jsonTag, ","); idx != -1 {
jsonTag = jsonTag[:idx]
}
if jsonTag != "" {
return jsonTag
}
}
return camelCaseToWords(field.Name)
}
func camelCaseToWords(s string) string {
if s == "" {
return s
}
var result strings.Builder
for i, r := range s {
if i > 0 && unicode.IsUpper(r) {
result.WriteRune(' ')
}
result.WriteRune(unicode.ToLower(r))
}
return result.String()
}
func ValidateStruct(s interface{}) error {
if s == nil {
return nil
@@ -232,9 +259,10 @@ func ValidateStruct(s interface{}) error {
tagName = tag
}
if err := validateField(field.Name, fieldVal, tagName, param, omitempty); err != nil {
displayName := getFieldDisplayName(field)
if err := validateField(displayName, fieldVal, tagName, param, omitempty); err != nil {
errors = append(errors, FieldValidationError{
Field: field.Name,
Field: displayName,
Tag: tagName,
Param: param,
Message: err.Message,
@@ -293,7 +321,7 @@ func validateField(fieldName string, fieldVal reflect.Value, tagName, param stri
func isEmptyValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.String:
return v.String() == ""
return strings.TrimSpace(v.String()) == ""
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
@@ -301,7 +329,7 @@ func isEmptyValue(v reflect.Value) bool {
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Bool:
return !v.Bool()
return false
case reflect.Pointer, reflect.Interface, reflect.Slice, reflect.Map:
return v.IsNil()
default:
@@ -317,7 +345,8 @@ func validateMin(fieldName string, v reflect.Value, param string) *ValidationErr
switch v.Kind() {
case reflect.String:
if len(v.String()) < min {
s := strings.TrimSpace(v.String())
if len([]rune(s)) < min {
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at least %d characters", fieldName, min)}
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -344,7 +373,7 @@ func validateMax(fieldName string, v reflect.Value, param string) *ValidationErr
switch v.Kind() {
case reflect.String:
if len(v.String()) > max {
if len([]rune(v.String())) > max {
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at most %d characters", fieldName, max)}
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:

View File

@@ -250,12 +250,12 @@ func TestSanitizeString(t *testing.T) {
func TestValidateStruct(t *testing.T) {
type TestStruct struct {
Username string `validate:"required,min=3,max=20"`
Email string `validate:"required,email"`
Age int `validate:"min=18,max=120"`
URL string `validate:"url"`
Status string `validate:"oneof=active inactive pending"`
Optional string `validate:"omitempty,min=1"`
Username string `json:"username" validate:"required,min=3,max=20"`
Email string `json:"email" validate:"required,email"`
Age int `json:"age" validate:"min=18,max=120"`
URL string `json:"url" validate:"url"`
Status string `json:"status" validate:"oneof=active inactive pending"`
Optional string `json:"optional" validate:"omitempty,min=1"`
}
t.Run("valid struct", func(t *testing.T) {
@@ -287,6 +287,9 @@ func TestValidateStruct(t *testing.T) {
if len(structErr.Errors) == 0 {
t.Error("Expected validation errors, got none")
}
if structErr.Errors[0].Message != "username is required" {
t.Errorf("Expected JSON tag name in error, got %q", structErr.Errors[0].Message)
}
}
})
@@ -318,6 +321,20 @@ func TestValidateStruct(t *testing.T) {
}
})
t.Run("whitespace required field", func(t *testing.T) {
s := TestStruct{
Username: " ",
Email: "test@example.com",
Age: 25,
URL: "https://example.com",
Status: "active",
}
err := ValidateStruct(s)
if err == nil {
t.Error("ValidateStruct() expected error, got nil")
}
})
t.Run("invalid max", func(t *testing.T) {
s := TestStruct{
Username: strings.Repeat("a", 21),

View File

@@ -1,3 +1,7 @@
package version
const Version = "0.1.0"
const version = "0.1.1"
func GetVersion() string {
return version
}

View File

@@ -8,39 +8,39 @@ import (
func TestVersionSemver(t *testing.T) {
semverRegex := regexp.MustCompile(`^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$`)
if !semverRegex.MatchString(Version) {
t.Errorf("Version %q does not follow semantic versioning format (MAJOR.MINOR.PATCH[-PRERELEASE][+BUILD])", Version)
if !semverRegex.MatchString(GetVersion()) {
t.Errorf("Version %q does not follow semantic versioning format (MAJOR.MINOR.PATCH[-PRERELEASE][+BUILD])", GetVersion())
}
}
func TestVersionSemverFlexible(t *testing.T) {
flexibleSemverRegex := regexp.MustCompile(`^(0|[1-9]\d*)\.(0|[1-9]\d*)(?:\.(0|[1-9]\d*))?(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$`)
if !flexibleSemverRegex.MatchString(Version) {
t.Errorf("Version %q does not follow semantic versioning format (MAJOR.MINOR[.PATCH][-PRERELEASE][+BUILD])", Version)
if !flexibleSemverRegex.MatchString(GetVersion()) {
t.Errorf("Version %q does not follow semantic versioning format (MAJOR.MINOR[.PATCH][-PRERELEASE][+BUILD])", GetVersion())
}
}
func TestVersionNotEmpty(t *testing.T) {
if Version == "" {
if GetVersion() == "" {
t.Error("Version should not be empty")
}
}
func TestVersionFormat(t *testing.T) {
if !regexp.MustCompile(`\d+\.\d+`).MatchString(Version) {
t.Errorf("Version %q should contain at least MAJOR.MINOR format", Version)
if !regexp.MustCompile(`\d+\.\d+`).MatchString(GetVersion()) {
t.Errorf("Version %q should contain at least MAJOR.MINOR format", GetVersion())
}
}
func TestVersionStartsWithNumber(t *testing.T) {
if !regexp.MustCompile(`^\d+`).MatchString(Version) {
t.Errorf("Version %q should start with a number", Version)
if !regexp.MustCompile(`^\d+`).MatchString(GetVersion()) {
t.Errorf("Version %q should start with a number", GetVersion())
}
}
func TestVersionNoLeadingZeros(t *testing.T) {
parts := regexp.MustCompile(`^(\d+)\.(\d+)`).FindStringSubmatch(Version)
parts := regexp.MustCompile(`^(\d+)\.(\d+)`).FindStringSubmatch(GetVersion())
if len(parts) >= 3 {
major := parts[1]
minor := parts[2]