Compare commits
90 Commits
75a33994db
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 65109a787c | |||
| 75f1406edf | |||
| 11dc9b507f | |||
| da616438e9 | |||
| 7486865343 | |||
| fd0fd8954a | |||
| 628db14f59 | |||
| 7be196e4c3 | |||
| 2f4bd45efb | |||
| 1b53c2b66b | |||
| 509e68f538 | |||
| e6a44d830e | |||
| fe396b7537 | |||
| 6eb04aa3c5 | |||
| 517d4482c9 | |||
| b6e2bf942a | |||
| 9f1058ba81 | |||
| 2bdbb29ae6 | |||
| 9d243a0ed1 | |||
| 9c74828b8d | |||
| 9e78477eb5 | |||
| a74980caa1 | |||
| 816f08a20a | |||
| 0cec152486 | |||
| 5413737491 | |||
| 5f605e45c7 | |||
| e5779183ff | |||
| 4814b64c2c | |||
| 45cad505d6 | |||
| 7f52347854 | |||
| 542913cbef | |||
| 2f964b0c79 | |||
| 250ff79eeb | |||
| 4dfe260953 | |||
| 49e6bb1e9d | |||
| 5b0c6018c0 | |||
| 3303d13f15 | |||
| c1746eb346 | |||
| e2804ca07e | |||
| 6cdad79caa | |||
| 6227b64746 | |||
| 506e233347 | |||
| 8c06c916e1 | |||
| 29fcaab25d | |||
| 422ff2473e | |||
| dbe4879457 | |||
| 5a530b7609 | |||
| 66b4b0e173 | |||
| e08e2b3189 | |||
| f39dcff67d | |||
| 08d8d0ed22 | |||
| 932c042aa2 | |||
| a1466e860d | |||
| a1e63b868f | |||
| b6f5293c0f | |||
| 6643466d76 | |||
| 9dcf748474 | |||
| 1ff1c8faf4 | |||
| 0bcc1eb427 | |||
| cbfe0fd54c | |||
| 1727ae4a7c | |||
| ef4a05f8a5 | |||
| 00ef0c236e | |||
| 2d58c15031 | |||
| 523dac242e | |||
| 53da1eee2a | |||
| 20ea6c4a27 | |||
| 0e557c3f89 | |||
| 56770955d4 | |||
| 34fbc2f8b1 | |||
| 05e69c7f36 | |||
| 9ff7c98cf0 | |||
| 893ee154de | |||
| dfee90504a | |||
| bc0c9e5fea | |||
| 35ef42eb93 | |||
| 9ceaf35fd9 | |||
| 395cc299f3 | |||
| 058c69b414 | |||
| d744aa8393 | |||
| 44e2f97cb7 | |||
| 4888916613 | |||
| 3ca2334932 | |||
| 02d0c3f946 | |||
| 1b55c9543e | |||
| 73930dabd8 | |||
| 19aadc6fc8 | |||
| 8bdff51eed | |||
| ac2dfdde70 | |||
| c3d0d16e44 |
@@ -44,7 +44,6 @@ linters:
|
|||||||
|
|
||||||
gocritic:
|
gocritic:
|
||||||
disabled-checks:
|
disabled-checks:
|
||||||
- hugeParam
|
|
||||||
- ifElseChain
|
- ifElseChain
|
||||||
settings:
|
settings:
|
||||||
captLocal:
|
captLocal:
|
||||||
|
|||||||
46
README.md
46
README.md
@@ -117,6 +117,8 @@ JWT_EXPIRATION=1
|
|||||||
JWT_REFRESH_EXPIRATION=168
|
JWT_REFRESH_EXPIRATION=168
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Refresh tokens rotate on each successful refresh, the previous refresh token is invalidated.
|
||||||
|
|
||||||
### SMTP Configuration
|
### SMTP Configuration
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -201,22 +203,44 @@ It'll be more readable and easier to parse.
|
|||||||
|
|
||||||
- `POST /api/auth/register` - Register new user
|
- `POST /api/auth/register` - Register new user
|
||||||
- `POST /api/auth/login` - Login user
|
- `POST /api/auth/login` - Login user
|
||||||
- `GET /api/auth/confirm` - Confirm email
|
- `GET /api/auth/confirm` - Confirm email address
|
||||||
- `POST /api/auth/logout` - Logout user
|
- `POST /api/auth/resend-verification` - Resend verification email
|
||||||
|
- `POST /api/auth/forgot-password` - Request password reset
|
||||||
|
- `POST /api/auth/reset-password` - Reset password
|
||||||
|
- `GET /api/auth/me` - Get current user profile (protected)
|
||||||
|
- `POST /api/auth/logout` - Logout user (protected)
|
||||||
|
- `POST /api/auth/refresh` - Refresh access token (rotates refresh token)
|
||||||
|
- `POST /api/auth/revoke` - Revoke a refresh token (protected)
|
||||||
|
- `POST /api/auth/revoke-all` - Revoke all refresh tokens for the current user (protected)
|
||||||
|
- `PUT /api/auth/email` - Update email address (protected)
|
||||||
|
- `PUT /api/auth/username` - Update username (protected)
|
||||||
|
- `PUT /api/auth/password` - Update password (protected)
|
||||||
|
- `DELETE /api/auth/account` - Request account deletion (protected)
|
||||||
|
- `POST /api/auth/account/confirm` - Confirm account deletion
|
||||||
|
|
||||||
#### Posts
|
#### Posts
|
||||||
|
|
||||||
- `GET /api/posts` - List posts
|
- `GET /api/posts` - List posts
|
||||||
- `POST /api/posts` - Create post
|
- `GET /api/posts/search` - Search posts
|
||||||
|
- `GET /api/posts/title` - Fetch title from URL
|
||||||
- `GET /api/posts/{id}` - Get specific post
|
- `GET /api/posts/{id}` - Get specific post
|
||||||
- `PUT /api/posts/{id}` - Update post
|
- `POST /api/posts` - Create post (protected)
|
||||||
- `DELETE /api/posts/{id}` - Delete post
|
- `PUT /api/posts/{id}` - Update post (protected)
|
||||||
|
- `DELETE /api/posts/{id}` - Delete post (protected)
|
||||||
|
|
||||||
#### Voting
|
#### Voting
|
||||||
|
|
||||||
- `POST /api/posts/{id}/vote` - Cast vote
|
- `POST /api/posts/{id}/vote` - Cast vote (protected)
|
||||||
- `DELETE /api/posts/{id}/vote` - Remove vote
|
- `DELETE /api/posts/{id}/vote` - Remove vote (protected)
|
||||||
- `GET /api/posts/{id}/votes` - Get post votes
|
- `GET /api/posts/{id}/vote` - Get current user's vote (protected)
|
||||||
|
- `GET /api/posts/{id}/votes` - Get all votes for post
|
||||||
|
|
||||||
|
#### Users
|
||||||
|
|
||||||
|
- `GET /api/users` - List all users (protected)
|
||||||
|
- `POST /api/users` - Create new user (protected)
|
||||||
|
- `GET /api/users/{id}` - Get specific user
|
||||||
|
- `GET /api/users/{id}/posts` - Get user's posts (protected)
|
||||||
|
|
||||||
## CLI Commands
|
## CLI Commands
|
||||||
|
|
||||||
@@ -395,14 +419,10 @@ This will regenerate the swagger documentation and update the `docs/swagger.json
|
|||||||
|
|
||||||
## Roadmap
|
## Roadmap
|
||||||
|
|
||||||
- [ ] migrate cli to urfave/cli
|
- [x] migrate cli to urfave/cli
|
||||||
- [ ] add a ML powered nsfw link detection
|
- [ ] add a ML powered nsfw link detection
|
||||||
- [ ] add right management within the app
|
|
||||||
- [ ] add an admin backoffice to manage rights, users, content and settings
|
|
||||||
- [ ] add a way to run read-only communities
|
- [ ] add a way to run read-only communities
|
||||||
- [ ] migrate raw CSS to UnoCSS
|
- [ ] migrate raw CSS to UnoCSS
|
||||||
- [ ] kubernetes deployment
|
|
||||||
- [ ] store configuration in the database
|
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|||||||
132
cmd/goyco/cli.go
132
cmd/goyco/cli.go
@@ -1,14 +1,24 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"goyco/cmd/goyco/commands"
|
"goyco/cmd/goyco/commands"
|
||||||
|
"goyco/internal/config"
|
||||||
|
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
|
"github.com/urfave/cli/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
helpPrinterOnce sync.Once
|
||||||
|
defaultHelpPrinter func(io.Writer, string, interface{})
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadDotEnv() {
|
func loadDotEnv() {
|
||||||
@@ -55,3 +65,125 @@ func printRunUsage() {
|
|||||||
fmt.Fprintln(os.Stderr, "Usage: goyco run")
|
fmt.Fprintln(os.Stderr, "Usage: goyco run")
|
||||||
fmt.Fprintln(os.Stderr, "\nStart the web application in foreground.")
|
fmt.Fprintln(os.Stderr, "\nStart the web application in foreground.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildRootCommand(cfg *config.Config) *cli.Command {
|
||||||
|
helpPrinterOnce.Do(func() {
|
||||||
|
defaultHelpPrinter = cli.HelpPrinter
|
||||||
|
})
|
||||||
|
cli.HelpPrinter = func(w io.Writer, templ string, data interface{}) {
|
||||||
|
if cmd, ok := data.(*cli.Command); ok && cmd.Root() == cmd {
|
||||||
|
printRootUsage()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defaultHelpPrinter(w, templ, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
root := &cli.Command{
|
||||||
|
Name: "goyco",
|
||||||
|
Usage: "Y Combinator-style news aggregation platform API",
|
||||||
|
UsageText: "goyco <command> [<args>]",
|
||||||
|
HideVersion: true,
|
||||||
|
Flags: []cli.Flag{
|
||||||
|
&cli.BoolFlag{
|
||||||
|
Name: "json",
|
||||||
|
Usage: "output results in JSON format",
|
||||||
|
Value: cfg.CLI.JSONOutputDefault,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
|
||||||
|
commands.SetJSONOutput(cmd.Bool("json"))
|
||||||
|
return ctx, nil
|
||||||
|
},
|
||||||
|
After: func(ctx context.Context, cmd *cli.Command) error {
|
||||||
|
cli.HelpPrinter = defaultHelpPrinter
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Action: func(_ context.Context, cmd *cli.Command) error {
|
||||||
|
if cmd.NArg() == 0 {
|
||||||
|
printRootUsage()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
printRootUsage()
|
||||||
|
return fmt.Errorf("unknown command: %s", cmd.Args().First())
|
||||||
|
},
|
||||||
|
Commands: []*cli.Command{
|
||||||
|
{
|
||||||
|
Name: "run",
|
||||||
|
Usage: "start the web application in foreground",
|
||||||
|
SkipFlagParsing: true,
|
||||||
|
Action: func(_ context.Context, cmd *cli.Command) error {
|
||||||
|
return handleRunCommand(cfg, cmd.Args().Slice())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "start",
|
||||||
|
Usage: "start the web application in background",
|
||||||
|
SkipFlagParsing: true,
|
||||||
|
Action: func(_ context.Context, cmd *cli.Command) error {
|
||||||
|
return commands.HandleStartCommand(cfg, cmd.Args().Slice())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "stop",
|
||||||
|
Usage: "stop the daemon",
|
||||||
|
SkipFlagParsing: true,
|
||||||
|
Action: func(_ context.Context, cmd *cli.Command) error {
|
||||||
|
return commands.HandleStopCommand(cfg, cmd.Args().Slice())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "status",
|
||||||
|
Usage: "check if the daemon is running",
|
||||||
|
SkipFlagParsing: true,
|
||||||
|
Action: func(_ context.Context, cmd *cli.Command) error {
|
||||||
|
return commands.HandleStatusCommand(cfg, cmd.Name, cmd.Args().Slice())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "migrate",
|
||||||
|
Aliases: []string{"migrations"},
|
||||||
|
Usage: "apply database migrations",
|
||||||
|
SkipFlagParsing: true,
|
||||||
|
Action: func(_ context.Context, cmd *cli.Command) error {
|
||||||
|
return commands.HandleMigrateCommand(cfg, cmd.Name, cmd.Args().Slice())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "user",
|
||||||
|
Usage: "manage users (create, update, delete, lock, list)",
|
||||||
|
SkipFlagParsing: true,
|
||||||
|
Action: func(_ context.Context, cmd *cli.Command) error {
|
||||||
|
return commands.HandleUserCommand(cfg, cmd.Name, cmd.Args().Slice())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "post",
|
||||||
|
Usage: "manage posts (delete, list, search)",
|
||||||
|
SkipFlagParsing: true,
|
||||||
|
Action: func(_ context.Context, cmd *cli.Command) error {
|
||||||
|
return commands.HandlePostCommand(cfg, cmd.Name, cmd.Args().Slice())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "prune",
|
||||||
|
Usage: "hard delete users and posts (posts, all)",
|
||||||
|
SkipFlagParsing: true,
|
||||||
|
Action: func(_ context.Context, cmd *cli.Command) error {
|
||||||
|
return commands.HandlePruneCommand(cfg, cmd.Name, cmd.Args().Slice())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "seed",
|
||||||
|
Usage: "seed database with random data",
|
||||||
|
SkipFlagParsing: true,
|
||||||
|
Action: func(_ context.Context, cmd *cli.Command) error {
|
||||||
|
return commands.HandleSeedCommand(cfg, cmd.Name, cmd.Args().Slice())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Writer: os.Stdout,
|
||||||
|
ErrWriter: os.Stderr,
|
||||||
|
}
|
||||||
|
|
||||||
|
return root
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"os"
|
"os"
|
||||||
@@ -131,25 +132,26 @@ func TestPrintRunUsage(t *testing.T) {
|
|||||||
printRunUsage()
|
printRunUsage()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDispatchCommand(t *testing.T) {
|
func TestRootCommandDispatch(t *testing.T) {
|
||||||
|
|
||||||
t.Run("unknown command", func(t *testing.T) {
|
t.Run("unknown command", func(t *testing.T) {
|
||||||
cfg := testutils.NewTestConfig()
|
cfg := testutils.NewTestConfig()
|
||||||
err := dispatchCommand(cfg, "unknown", []string{})
|
cmd := buildRootCommand(cfg)
|
||||||
|
err := cmd.Run(context.Background(), []string{"goyco", "unknown"})
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for unknown command")
|
t.Error("expected error for unknown command")
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedErr := "unknown command: unknown"
|
expectedErr := "unknown command: unknown"
|
||||||
if err.Error() != expectedErr {
|
if err != nil && err.Error() != expectedErr {
|
||||||
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
|
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("help command", func(t *testing.T) {
|
t.Run("help command", func(t *testing.T) {
|
||||||
cfg := testutils.NewTestConfig()
|
cfg := testutils.NewTestConfig()
|
||||||
err := dispatchCommand(cfg, "help", []string{})
|
cmd := buildRootCommand(cfg)
|
||||||
|
err := cmd.Run(context.Background(), []string{"goyco", "help"})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error for help command: %v", err)
|
t.Errorf("unexpected error for help command: %v", err)
|
||||||
@@ -158,7 +160,8 @@ func TestDispatchCommand(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("h command", func(t *testing.T) {
|
t.Run("h command", func(t *testing.T) {
|
||||||
cfg := testutils.NewTestConfig()
|
cfg := testutils.NewTestConfig()
|
||||||
err := dispatchCommand(cfg, "-h", []string{})
|
cmd := buildRootCommand(cfg)
|
||||||
|
err := cmd.Run(context.Background(), []string{"goyco", "-h"})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error for -h command: %v", err)
|
t.Errorf("unexpected error for -h command: %v", err)
|
||||||
@@ -167,7 +170,8 @@ func TestDispatchCommand(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("--help command", func(t *testing.T) {
|
t.Run("--help command", func(t *testing.T) {
|
||||||
cfg := testutils.NewTestConfig()
|
cfg := testutils.NewTestConfig()
|
||||||
err := dispatchCommand(cfg, "--help", []string{})
|
cmd := buildRootCommand(cfg)
|
||||||
|
err := cmd.Run(context.Background(), []string{"goyco", "--help"})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error for --help command: %v", err)
|
t.Errorf("unexpected error for --help command: %v", err)
|
||||||
@@ -179,7 +183,8 @@ func TestDispatchCommand(t *testing.T) {
|
|||||||
|
|
||||||
useInMemoryCommandsConnector(t)
|
useInMemoryCommandsConnector(t)
|
||||||
|
|
||||||
err := dispatchCommand(cfg, "post", []string{"list"})
|
cmd := buildRootCommand(cfg)
|
||||||
|
err := cmd.Run(context.Background(), []string{"goyco", "post", "list"})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error for post list: %v", err)
|
t.Errorf("unexpected error for post list: %v", err)
|
||||||
|
|||||||
@@ -1,533 +0,0 @@
|
|||||||
package commands
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
cryptoRand "crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"goyco/internal/database"
|
|
||||||
"goyco/internal/repositories"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ParallelProcessor struct {
|
|
||||||
maxWorkers int
|
|
||||||
timeout time.Duration
|
|
||||||
passwordHash string
|
|
||||||
randSource *rand.Rand
|
|
||||||
randMu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewParallelProcessor() *ParallelProcessor {
|
|
||||||
maxWorkers := max(min(runtime.NumCPU(), 8), 2)
|
|
||||||
|
|
||||||
seed := time.Now().UnixNano()
|
|
||||||
seedBytes := make([]byte, 8)
|
|
||||||
if _, err := cryptoRand.Read(seedBytes); err == nil {
|
|
||||||
seed = int64(seedBytes[0])<<56 | int64(seedBytes[1])<<48 | int64(seedBytes[2])<<40 | int64(seedBytes[3])<<32 |
|
|
||||||
int64(seedBytes[4])<<24 | int64(seedBytes[5])<<16 | int64(seedBytes[6])<<8 | int64(seedBytes[7])
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ParallelProcessor{
|
|
||||||
maxWorkers: maxWorkers,
|
|
||||||
timeout: 60 * time.Second,
|
|
||||||
randSource: rand.New(rand.NewSource(seed)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParallelProcessor) SetPasswordHash(hash string) {
|
|
||||||
p.passwordHash = hash
|
|
||||||
}
|
|
||||||
|
|
||||||
type indexedResult[T any] struct {
|
|
||||||
value T
|
|
||||||
index int
|
|
||||||
}
|
|
||||||
|
|
||||||
func processInParallel[T any](
|
|
||||||
ctx context.Context,
|
|
||||||
maxWorkers int,
|
|
||||||
count int,
|
|
||||||
processor func(index int) (T, error),
|
|
||||||
errorPrefix string,
|
|
||||||
progress *ProgressIndicator,
|
|
||||||
) ([]T, error) {
|
|
||||||
results := make(chan indexedResult[T], count)
|
|
||||||
errors := make(chan error, count)
|
|
||||||
|
|
||||||
semaphore := make(chan struct{}, maxWorkers)
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
for i := range count {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(index int) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case semaphore <- struct{}{}:
|
|
||||||
case <-ctx.Done():
|
|
||||||
errors <- ctx.Err()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() { <-semaphore }()
|
|
||||||
|
|
||||||
value, err := processor(index + 1)
|
|
||||||
if err != nil {
|
|
||||||
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
results <- indexedResult[T]{value: value, index: index}
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
wg.Wait()
|
|
||||||
close(results)
|
|
||||||
close(errors)
|
|
||||||
}()
|
|
||||||
|
|
||||||
items := make([]T, count)
|
|
||||||
completed := 0
|
|
||||||
firstError := make(chan error, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for err := range errors {
|
|
||||||
if err != nil {
|
|
||||||
select {
|
|
||||||
case firstError <- err:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for completed < count {
|
|
||||||
select {
|
|
||||||
case result, ok := <-results:
|
|
||||||
if !ok {
|
|
||||||
return items, nil
|
|
||||||
}
|
|
||||||
items[result.index] = result.value
|
|
||||||
completed++
|
|
||||||
if progress != nil {
|
|
||||||
progress.Update(completed)
|
|
||||||
}
|
|
||||||
case err := <-firstError:
|
|
||||||
return nil, err
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, fmt.Errorf("timeout: %w", ctx.Err())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return items, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
return processInParallel(ctx, p.maxWorkers, count,
|
|
||||||
func(index int) (database.User, error) {
|
|
||||||
return p.createSingleUser(userRepo, index)
|
|
||||||
},
|
|
||||||
"create user",
|
|
||||||
progress,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
return processInParallel(ctx, p.maxWorkers, count,
|
|
||||||
func(index int) (database.Post, error) {
|
|
||||||
return p.createSinglePost(postRepo, authorID, index)
|
|
||||||
},
|
|
||||||
"create post",
|
|
||||||
progress,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func processItemsInParallel[T any, R any](
|
|
||||||
ctx context.Context,
|
|
||||||
maxWorkers int,
|
|
||||||
items []T,
|
|
||||||
processor func(index int, item T) (R, error),
|
|
||||||
errorPrefix string,
|
|
||||||
aggregator func(accumulator R, value R) R,
|
|
||||||
initialValue R,
|
|
||||||
progress *ProgressIndicator,
|
|
||||||
) (R, error) {
|
|
||||||
count := len(items)
|
|
||||||
results := make(chan indexedResult[R], count)
|
|
||||||
errors := make(chan error, count)
|
|
||||||
|
|
||||||
semaphore := make(chan struct{}, maxWorkers)
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
for i, item := range items {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(index int, item T) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case semaphore <- struct{}{}:
|
|
||||||
case <-ctx.Done():
|
|
||||||
errors <- ctx.Err()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() { <-semaphore }()
|
|
||||||
|
|
||||||
value, err := processor(index, item)
|
|
||||||
if err != nil {
|
|
||||||
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
results <- indexedResult[R]{value: value, index: index}
|
|
||||||
}(i, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
wg.Wait()
|
|
||||||
close(results)
|
|
||||||
close(errors)
|
|
||||||
}()
|
|
||||||
|
|
||||||
accumulator := initialValue
|
|
||||||
completed := 0
|
|
||||||
firstError := make(chan error, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for err := range errors {
|
|
||||||
if err != nil {
|
|
||||||
select {
|
|
||||||
case firstError <- err:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for completed < count {
|
|
||||||
select {
|
|
||||||
case result, ok := <-results:
|
|
||||||
if !ok {
|
|
||||||
return accumulator, nil
|
|
||||||
}
|
|
||||||
accumulator = aggregator(accumulator, result.value)
|
|
||||||
completed++
|
|
||||||
if progress != nil {
|
|
||||||
progress.Update(completed)
|
|
||||||
}
|
|
||||||
case err := <-firstError:
|
|
||||||
return initialValue, err
|
|
||||||
case <-ctx.Done():
|
|
||||||
return initialValue, fmt.Errorf("timeout: %w", ctx.Err())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return accumulator, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
return processItemsInParallel(ctx, p.maxWorkers, posts,
|
|
||||||
func(index int, post database.Post) (int, error) {
|
|
||||||
return p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
|
|
||||||
},
|
|
||||||
"create votes for post",
|
|
||||||
func(acc, val int) int { return acc + val },
|
|
||||||
0,
|
|
||||||
progress,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func processItemsInParallelNoResult[T any](
|
|
||||||
ctx context.Context,
|
|
||||||
maxWorkers int,
|
|
||||||
items []T,
|
|
||||||
processor func(index int, item T) error,
|
|
||||||
errorFormatter func(index int, item T, err error) error,
|
|
||||||
progress *ProgressIndicator,
|
|
||||||
) error {
|
|
||||||
count := len(items)
|
|
||||||
errors := make(chan error, count)
|
|
||||||
completions := make(chan struct{}, count)
|
|
||||||
|
|
||||||
semaphore := make(chan struct{}, maxWorkers)
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
for i, item := range items {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(index int, item T) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case semaphore <- struct{}{}:
|
|
||||||
case <-ctx.Done():
|
|
||||||
errors <- ctx.Err()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() { <-semaphore }()
|
|
||||||
|
|
||||||
err := processor(index, item)
|
|
||||||
if err != nil {
|
|
||||||
if errorFormatter != nil {
|
|
||||||
errors <- errorFormatter(index, item, err)
|
|
||||||
} else {
|
|
||||||
errors <- fmt.Errorf("process item %d: %w", index+1, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
completions <- struct{}{}
|
|
||||||
}(i, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
wg.Wait()
|
|
||||||
close(errors)
|
|
||||||
close(completions)
|
|
||||||
}()
|
|
||||||
|
|
||||||
completed := 0
|
|
||||||
firstError := make(chan error, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for err := range errors {
|
|
||||||
if err != nil {
|
|
||||||
select {
|
|
||||||
case firstError <- err:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for completed < count {
|
|
||||||
select {
|
|
||||||
case _, ok := <-completions:
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
completed++
|
|
||||||
if progress != nil {
|
|
||||||
progress.Update(completed)
|
|
||||||
}
|
|
||||||
case err := <-firstError:
|
|
||||||
return err
|
|
||||||
case <-ctx.Done():
|
|
||||||
return fmt.Errorf("timeout: %w", ctx.Err())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
return processItemsInParallelNoResult(ctx, p.maxWorkers, posts,
|
|
||||||
func(index int, post database.Post) error {
|
|
||||||
return p.updateSinglePostScore(postRepo, voteRepo, post)
|
|
||||||
},
|
|
||||||
func(index int, post database.Post, err error) error {
|
|
||||||
return fmt.Errorf("update post %d scores: %w", post.ID, err)
|
|
||||||
},
|
|
||||||
progress,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParallelProcessor) generateRandomIdentifier() string {
|
|
||||||
const length = 12
|
|
||||||
const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
||||||
identifier := make([]byte, length)
|
|
||||||
p.randMu.Lock()
|
|
||||||
for i := range identifier {
|
|
||||||
identifier[i] = chars[p.randSource.Intn(len(chars))]
|
|
||||||
}
|
|
||||||
p.randMu.Unlock()
|
|
||||||
return string(identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParallelProcessor) createSingleUser(userRepo repositories.UserRepository, index int) (database.User, error) {
|
|
||||||
randomID := p.generateRandomIdentifier()
|
|
||||||
username := fmt.Sprintf("user_%s", randomID)
|
|
||||||
email := fmt.Sprintf("user_%s@goyco.local", randomID)
|
|
||||||
|
|
||||||
const maxRetries = 10
|
|
||||||
for range maxRetries {
|
|
||||||
user := &database.User{
|
|
||||||
Username: username,
|
|
||||||
Email: email,
|
|
||||||
Password: p.passwordHash,
|
|
||||||
EmailVerified: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := userRepo.Create(user); err != nil {
|
|
||||||
randomID = p.generateRandomIdentifier()
|
|
||||||
username = fmt.Sprintf("user_%s", randomID)
|
|
||||||
email = fmt.Sprintf("user_%s@goyco.local", randomID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return *user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return database.User{}, fmt.Errorf("failed to create user after %d attempts", maxRetries)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParallelProcessor) createSinglePost(postRepo repositories.PostRepository, authorID uint, index int) (database.Post, error) {
|
|
||||||
sampleTitles := []string{
|
|
||||||
"Amazing JavaScript Framework",
|
|
||||||
"Python Best Practices",
|
|
||||||
"Go Performance Tips",
|
|
||||||
"Database Optimization",
|
|
||||||
"Web Security Guide",
|
|
||||||
"Machine Learning Basics",
|
|
||||||
"Cloud Architecture",
|
|
||||||
"DevOps Automation",
|
|
||||||
"API Design Patterns",
|
|
||||||
"Frontend Optimization",
|
|
||||||
"Backend Scaling",
|
|
||||||
"Container Orchestration",
|
|
||||||
"Microservices Architecture",
|
|
||||||
"Testing Strategies",
|
|
||||||
"Code Review Process",
|
|
||||||
"Version Control Best Practices",
|
|
||||||
"Continuous Integration",
|
|
||||||
"Monitoring and Alerting",
|
|
||||||
"Error Handling Patterns",
|
|
||||||
"Data Structures Explained",
|
|
||||||
}
|
|
||||||
|
|
||||||
sampleDomains := []string{
|
|
||||||
"example.com",
|
|
||||||
"techblog.org",
|
|
||||||
"devguide.net",
|
|
||||||
"programming.io",
|
|
||||||
"codeexamples.com",
|
|
||||||
"tutorialhub.org",
|
|
||||||
"bestpractices.dev",
|
|
||||||
"learnprogramming.net",
|
|
||||||
"codingtips.org",
|
|
||||||
"softwareengineering.com",
|
|
||||||
}
|
|
||||||
|
|
||||||
title := sampleTitles[index%len(sampleTitles)]
|
|
||||||
if index >= len(sampleTitles) {
|
|
||||||
title = fmt.Sprintf("%s - Part %d", title, (index/len(sampleTitles))+1)
|
|
||||||
}
|
|
||||||
|
|
||||||
domain := sampleDomains[index%len(sampleDomains)]
|
|
||||||
randomID := p.generateRandomIdentifier()
|
|
||||||
path := fmt.Sprintf("/article/%s", randomID)
|
|
||||||
url := fmt.Sprintf("https://%s%s", domain, path)
|
|
||||||
|
|
||||||
content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", index, title)
|
|
||||||
|
|
||||||
const maxRetries = 10
|
|
||||||
for range maxRetries {
|
|
||||||
post := &database.Post{
|
|
||||||
Title: title,
|
|
||||||
URL: url,
|
|
||||||
Content: content,
|
|
||||||
AuthorID: &authorID,
|
|
||||||
UpVotes: 0,
|
|
||||||
DownVotes: 0,
|
|
||||||
Score: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := postRepo.Create(post); err != nil {
|
|
||||||
randomID = p.generateRandomIdentifier()
|
|
||||||
path = fmt.Sprintf("/article/%s", randomID)
|
|
||||||
url = fmt.Sprintf("https://%s%s", domain, path)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return *post, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return database.Post{}, fmt.Errorf("failed to create post after %d attempts", maxRetries)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParallelProcessor) createVotesForPost(voteRepo repositories.VoteRepository, users []database.User, post database.Post, avgVotesPerPost int) (int, error) {
|
|
||||||
p.randMu.Lock()
|
|
||||||
numVotes := p.randSource.Intn(avgVotesPerPost*2 + 1)
|
|
||||||
p.randMu.Unlock()
|
|
||||||
|
|
||||||
if numVotes == 0 && avgVotesPerPost > 0 {
|
|
||||||
p.randMu.Lock()
|
|
||||||
if p.randSource.Intn(5) > 0 {
|
|
||||||
numVotes = 1
|
|
||||||
}
|
|
||||||
p.randMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
totalVotes := 0
|
|
||||||
usedUsers := make(map[uint]bool)
|
|
||||||
|
|
||||||
for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
|
|
||||||
p.randMu.Lock()
|
|
||||||
userIdx := p.randSource.Intn(len(users))
|
|
||||||
p.randMu.Unlock()
|
|
||||||
user := users[userIdx]
|
|
||||||
|
|
||||||
if usedUsers[user.ID] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
usedUsers[user.ID] = true
|
|
||||||
|
|
||||||
p.randMu.Lock()
|
|
||||||
voteTypeInt := p.randSource.Intn(10)
|
|
||||||
p.randMu.Unlock()
|
|
||||||
var voteType database.VoteType
|
|
||||||
if voteTypeInt < 7 {
|
|
||||||
voteType = database.VoteUp
|
|
||||||
} else {
|
|
||||||
voteType = database.VoteDown
|
|
||||||
}
|
|
||||||
|
|
||||||
vote := &database.Vote{
|
|
||||||
UserID: &user.ID,
|
|
||||||
PostID: post.ID,
|
|
||||||
Type: voteType,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := voteRepo.CreateOrUpdate(vote); err != nil {
|
|
||||||
return totalVotes, fmt.Errorf("create or update vote: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
totalVotes++
|
|
||||||
}
|
|
||||||
|
|
||||||
return totalVotes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParallelProcessor) updateSinglePostScore(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, post database.Post) error {
|
|
||||||
upVotes, downVotes, err := getVoteCounts(voteRepo, post.ID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get vote counts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
post.UpVotes = upVotes
|
|
||||||
post.DownVotes = downVotes
|
|
||||||
post.Score = upVotes - downVotes
|
|
||||||
|
|
||||||
if err := postRepo.Update(&post); err != nil {
|
|
||||||
return fmt.Errorf("update post: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,145 +0,0 @@
|
|||||||
package commands_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"goyco/cmd/goyco/commands"
|
|
||||||
"goyco/internal/database"
|
|
||||||
"goyco/internal/repositories"
|
|
||||||
"goyco/internal/testutils"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
|
|
||||||
const successCount = 4
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
count int
|
|
||||||
repoFactory func() repositories.UserRepository
|
|
||||||
progress *commands.ProgressIndicator
|
|
||||||
validate func(t *testing.T, got []database.User)
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "creates users with required fields",
|
|
||||||
count: successCount,
|
|
||||||
repoFactory: func() repositories.UserRepository {
|
|
||||||
base := testutils.NewMockUserRepository()
|
|
||||||
return newFakeUserRepo(base, 0, nil)
|
|
||||||
},
|
|
||||||
progress: nil,
|
|
||||||
validate: func(t *testing.T, got []database.User) {
|
|
||||||
t.Helper()
|
|
||||||
if len(got) != successCount {
|
|
||||||
t.Fatalf("expected %d users, got %d", successCount, len(got))
|
|
||||||
}
|
|
||||||
usernames := make(map[string]bool)
|
|
||||||
for i, user := range got {
|
|
||||||
if user.Username == "" {
|
|
||||||
t.Errorf("user %d expected non-empty username", i)
|
|
||||||
}
|
|
||||||
if len(user.Username) < 6 || user.Username[:5] != "user_" {
|
|
||||||
t.Errorf("user %d username should start with 'user_', got %q", i, user.Username)
|
|
||||||
}
|
|
||||||
if usernames[user.Username] {
|
|
||||||
t.Errorf("user %d duplicate username: %q", i, user.Username)
|
|
||||||
}
|
|
||||||
usernames[user.Username] = true
|
|
||||||
|
|
||||||
if user.Email == "" {
|
|
||||||
t.Errorf("user %d expected non-empty email", i)
|
|
||||||
}
|
|
||||||
if len(user.Email) < 20 || user.Email[:5] != "user_" || user.Email[len(user.Email)-12:] != "@goyco.local" {
|
|
||||||
t.Errorf("user %d email should match pattern 'user_*@goyco.local', got %q", i, user.Email)
|
|
||||||
}
|
|
||||||
if !user.EmailVerified {
|
|
||||||
t.Errorf("user %d expected EmailVerified to be true", i)
|
|
||||||
}
|
|
||||||
if user.ID == 0 {
|
|
||||||
t.Errorf("user %d expected non-zero ID", i)
|
|
||||||
}
|
|
||||||
if user.Password == "" {
|
|
||||||
t.Errorf("user %d expected hashed password to be populated", i)
|
|
||||||
}
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte("password123")); err != nil {
|
|
||||||
t.Errorf("user %d password not hashed correctly: %v", i, err)
|
|
||||||
}
|
|
||||||
if user.CreatedAt.IsZero() {
|
|
||||||
t.Errorf("user %d expected CreatedAt to be set", i)
|
|
||||||
}
|
|
||||||
if user.UpdatedAt.IsZero() {
|
|
||||||
t.Errorf("user %d expected UpdatedAt to be set", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "returns error when repository create fails",
|
|
||||||
count: 3,
|
|
||||||
repoFactory: func() repositories.UserRepository {
|
|
||||||
base := testutils.NewMockUserRepository()
|
|
||||||
return newFakeUserRepo(base, 1, errors.New("create failure"))
|
|
||||||
},
|
|
||||||
progress: nil,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
repo := tt.repoFactory()
|
|
||||||
p := commands.NewParallelProcessor()
|
|
||||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to generate password hash: %v", err)
|
|
||||||
}
|
|
||||||
p.SetPasswordHash(string(passwordHash))
|
|
||||||
got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress)
|
|
||||||
if gotErr != nil {
|
|
||||||
if !tt.wantErr {
|
|
||||||
t.Errorf("CreateUsersInParallel() failed: %v", gotErr)
|
|
||||||
}
|
|
||||||
if got != nil {
|
|
||||||
t.Error("expected nil result when error occurs")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.wantErr {
|
|
||||||
t.Fatal("CreateUsersInParallel() succeeded unexpectedly")
|
|
||||||
}
|
|
||||||
if tt.validate != nil {
|
|
||||||
tt.validate(t, got)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeUserRepo struct {
|
|
||||||
repositories.UserRepository
|
|
||||||
mu sync.Mutex
|
|
||||||
failAt int
|
|
||||||
err error
|
|
||||||
calls int
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFakeUserRepo(base repositories.UserRepository, failAt int, err error) *fakeUserRepo {
|
|
||||||
return &fakeUserRepo{
|
|
||||||
UserRepository: base,
|
|
||||||
failAt: failAt,
|
|
||||||
err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *fakeUserRepo) Create(user *database.User) error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.calls++
|
|
||||||
if r.failAt > 0 && r.calls >= r.failAt {
|
|
||||||
return r.err
|
|
||||||
}
|
|
||||||
return r.UserRepository.Create(user)
|
|
||||||
}
|
|
||||||
@@ -56,16 +56,16 @@ func newProgressIndicatorWithClock(total int, description string, c clock) *Prog
|
|||||||
|
|
||||||
func (p *ProgressIndicator) Update(current int) {
|
func (p *ProgressIndicator) Update(current int) {
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
defer p.mu.Unlock()
|
|
||||||
|
|
||||||
p.current = current
|
p.current = current
|
||||||
now := p.clock.Now()
|
now := p.clock.Now()
|
||||||
|
|
||||||
if now.Sub(p.lastUpdate) < 100*time.Millisecond {
|
if now.Sub(p.lastUpdate) < 100*time.Millisecond {
|
||||||
|
p.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.lastUpdate = now
|
p.lastUpdate = now
|
||||||
|
p.mu.Unlock()
|
||||||
p.display()
|
p.display()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,15 +44,14 @@ func captureOutput(fn func()) string {
|
|||||||
r, w, _ := os.Pipe()
|
r, w, _ := os.Pipe()
|
||||||
os.Stdout = w
|
os.Stdout = w
|
||||||
|
|
||||||
defer func() {
|
fn()
|
||||||
|
|
||||||
_ = w.Close()
|
_ = w.Close()
|
||||||
os.Stdout = old
|
os.Stdout = old
|
||||||
}()
|
|
||||||
|
|
||||||
fn()
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
_, _ = io.Copy(&buf, r)
|
_, _ = io.Copy(&buf, r)
|
||||||
|
_ = r.Close()
|
||||||
return buf.String()
|
return buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,40 +1,22 @@
|
|||||||
package commands
|
package commands
|
||||||
|
|
||||||
import (
|
import (
|
||||||
cryptoRand "crypto/rand"
|
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"gorm.io/gorm"
|
||||||
"goyco/internal/config"
|
"goyco/internal/config"
|
||||||
"goyco/internal/database"
|
"goyco/internal/database"
|
||||||
"goyco/internal/repositories"
|
"goyco/internal/repositories"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
seedRandSource *rand.Rand
|
|
||||||
seedRandOnce sync.Once
|
|
||||||
)
|
|
||||||
|
|
||||||
func initSeedRand() {
|
|
||||||
seedRandOnce.Do(func() {
|
|
||||||
seed := time.Now().UnixNano()
|
|
||||||
seedBytes := make([]byte, 8)
|
|
||||||
if _, err := cryptoRand.Read(seedBytes); err == nil {
|
|
||||||
seed = int64(seedBytes[0])<<56 | int64(seedBytes[1])<<48 | int64(seedBytes[2])<<40 | int64(seedBytes[3])<<32 |
|
|
||||||
int64(seedBytes[4])<<24 | int64(seedBytes[5])<<16 | int64(seedBytes[6])<<8 | int64(seedBytes[7])
|
|
||||||
}
|
|
||||||
seedRandSource = rand.New(rand.NewSource(seed))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
|
func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
|
||||||
fs := newFlagSet(name, printSeedUsage)
|
fs := newFlagSet(name, printSeedUsage)
|
||||||
if err := parseCommand(fs, args, name); err != nil {
|
if err := parseCommand(fs, args, name); err != nil {
|
||||||
@@ -45,11 +27,13 @@ func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return withDatabase(cfg, func(db *gorm.DB) error {
|
return withDatabase(cfg, func(db *gorm.DB) error {
|
||||||
userRepo := repositories.NewUserRepository(db)
|
return db.Transaction(func(tx *gorm.DB) error {
|
||||||
postRepo := repositories.NewPostRepository(db)
|
userRepo := repositories.NewUserRepository(db).WithTx(tx)
|
||||||
voteRepo := repositories.NewVoteRepository(db)
|
postRepo := repositories.NewPostRepository(db).WithTx(tx)
|
||||||
|
voteRepo := repositories.NewVoteRepository(db).WithTx(tx)
|
||||||
return runSeedCommand(userRepo, postRepo, voteRepo, fs.Args())
|
return runSeedCommand(userRepo, postRepo, voteRepo, fs.Args())
|
||||||
})
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func runSeedCommand(userRepo repositories.UserRepository, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, args []string) error {
|
func runSeedCommand(userRepo repositories.UserRepository, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, args []string) error {
|
||||||
@@ -78,45 +62,37 @@ func printSeedUsage() {
|
|||||||
fmt.Fprintln(os.Stderr, " --votes-per-post: average votes per post (default: 15)")
|
fmt.Fprintln(os.Stderr, " --votes-per-post: average votes per post (default: 15)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func clampFlagValue(value *int, min int, name string) {
|
||||||
|
if *value < min {
|
||||||
|
if !IsJSONOutput() {
|
||||||
|
fmt.Fprintf(os.Stderr, "Warning: --%s value %d is too low, clamping to %d\n", name, *value, min)
|
||||||
|
}
|
||||||
|
*value = min
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, args []string) error {
|
func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, args []string) error {
|
||||||
fs := flag.NewFlagSet("seed database", flag.ContinueOnError)
|
fs := flag.NewFlagSet("seed database", flag.ContinueOnError)
|
||||||
numPosts := fs.Int("posts", 40, "number of posts to create")
|
numPosts := fs.Int("posts", 40, "number of posts to create")
|
||||||
numUsers := fs.Int("users", 5, "number of additional users to create")
|
numUsers := fs.Int("users", 5, "number of additional users to create")
|
||||||
votesPerPost := fs.Int("votes-per-post", 15, "average votes per post")
|
votesPerPost := fs.Int("votes-per-post", 15, "average votes per post")
|
||||||
fs.SetOutput(os.Stderr)
|
fs.SetOutput(os.Stderr)
|
||||||
|
fs.Usage = func() {
|
||||||
|
fmt.Fprintln(os.Stderr, "Usage: goyco seed database [--posts <n>] [--users <n>] [--votes-per-post <n>]")
|
||||||
|
fmt.Fprintln(os.Stderr, "\nOptions:")
|
||||||
|
fs.PrintDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
if err := fs.Parse(args); err != nil {
|
if err := fs.Parse(args); err != nil {
|
||||||
|
if errors.Is(err, flag.ErrHelp) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
originalUsers := *numUsers
|
clampFlagValue(numUsers, 0, "users")
|
||||||
originalPosts := *numPosts
|
clampFlagValue(numPosts, 1, "posts")
|
||||||
originalVotesPerPost := *votesPerPost
|
clampFlagValue(votesPerPost, 0, "votes-per-post")
|
||||||
|
|
||||||
if *numUsers < 0 {
|
|
||||||
if !IsJSONOutput() {
|
|
||||||
fmt.Fprintf(os.Stderr, "Warning: --users value %d is negative, clamping to 0\n", *numUsers)
|
|
||||||
}
|
|
||||||
*numUsers = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
if *numPosts <= 0 {
|
|
||||||
if !IsJSONOutput() {
|
|
||||||
fmt.Fprintf(os.Stderr, "Warning: --posts value %d is too low, clamping to 1\n", *numPosts)
|
|
||||||
}
|
|
||||||
*numPosts = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if *votesPerPost < 0 {
|
|
||||||
if !IsJSONOutput() {
|
|
||||||
fmt.Fprintf(os.Stderr, "Warning: --votes-per-post value %d is negative, clamping to 0\n", *votesPerPost)
|
|
||||||
}
|
|
||||||
*votesPerPost = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
if !IsJSONOutput() && (originalUsers != *numUsers || originalPosts != *numPosts || originalVotesPerPost != *votesPerPost) {
|
|
||||||
fmt.Fprintf(os.Stderr, "Using clamped values: --users=%d --posts=%d --votes-per-post=%d\n", *numUsers, *numPosts, *votesPerPost)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !IsJSONOutput() {
|
if !IsJSONOutput() {
|
||||||
fmt.Println("Starting database seeding...")
|
fmt.Println("Starting database seeding...")
|
||||||
@@ -135,71 +111,35 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
|
|||||||
return fmt.Errorf("precompute user password hash: %w", err)
|
return fmt.Errorf("precompute user password hash: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
spinner := NewSpinner("Creating seed user")
|
|
||||||
if !IsJSONOutput() {
|
|
||||||
spinner.Spin()
|
|
||||||
}
|
|
||||||
|
|
||||||
seedUser, err := ensureSeedUser(userRepo, string(seedPasswordHash))
|
seedUser, err := ensureSeedUser(userRepo, string(seedPasswordHash))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !IsJSONOutput() {
|
|
||||||
spinner.Complete()
|
|
||||||
}
|
|
||||||
return fmt.Errorf("ensure seed user: %w", err)
|
return fmt.Errorf("ensure seed user: %w", err)
|
||||||
}
|
}
|
||||||
if !IsJSONOutput() {
|
if !IsJSONOutput() {
|
||||||
spinner.Complete()
|
|
||||||
fmt.Printf("Seed user ready: ID=%d Username=%s\n", seedUser.ID, seedUser.Username)
|
fmt.Printf("Seed user ready: ID=%d Username=%s\n", seedUser.ID, seedUser.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
processor := NewParallelProcessor()
|
generator := newSeedGenerator(string(userPasswordHash))
|
||||||
processor.SetPasswordHash(string(userPasswordHash))
|
allUsers := []database.User{*seedUser}
|
||||||
|
|
||||||
var progress *ProgressIndicator
|
users, err := createUsers(generator, userRepo, *numUsers, "Creating users")
|
||||||
if !IsJSONOutput() && *numUsers > 0 {
|
|
||||||
progress = NewProgressIndicator(*numUsers, "Creating users (parallel)")
|
|
||||||
}
|
|
||||||
users, err := processor.CreateUsersInParallel(userRepo, *numUsers, progress)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create random users: %w", err)
|
return err
|
||||||
}
|
|
||||||
if !IsJSONOutput() && progress != nil {
|
|
||||||
progress.Complete()
|
|
||||||
}
|
}
|
||||||
|
allUsers = append(allUsers, users...)
|
||||||
|
|
||||||
allUsers := append([]database.User{*seedUser}, users...)
|
posts, err := createPosts(generator, postRepo, seedUser.ID, *numPosts, "Creating posts")
|
||||||
|
|
||||||
if !IsJSONOutput() && *numPosts > 0 {
|
|
||||||
progress = NewProgressIndicator(*numPosts, "Creating posts (parallel)")
|
|
||||||
}
|
|
||||||
posts, err := processor.CreatePostsInParallel(postRepo, seedUser.ID, *numPosts, progress)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create random posts: %w", err)
|
return err
|
||||||
}
|
|
||||||
if !IsJSONOutput() && progress != nil {
|
|
||||||
progress.Complete()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !IsJSONOutput() && len(posts) > 0 {
|
votes, err := createVotes(generator, voteRepo, allUsers, posts, *votesPerPost, "Creating votes")
|
||||||
progress = NewProgressIndicator(len(posts), "Creating votes (parallel)")
|
|
||||||
}
|
|
||||||
votes, err := processor.CreateVotesInParallel(voteRepo, allUsers, posts, *votesPerPost, progress)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create random votes: %w", err)
|
return err
|
||||||
}
|
|
||||||
if !IsJSONOutput() && progress != nil {
|
|
||||||
progress.Complete()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !IsJSONOutput() && len(posts) > 0 {
|
if err := updateScores(generator, postRepo, voteRepo, posts, "Updating scores"); err != nil {
|
||||||
progress = NewProgressIndicator(len(posts), "Updating scores (parallel)")
|
return err
|
||||||
}
|
|
||||||
err = processor.UpdatePostScoresInParallel(postRepo, voteRepo, posts, progress)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("update post scores: %w", err)
|
|
||||||
}
|
|
||||||
if !IsJSONOutput() && progress != nil {
|
|
||||||
progress.Complete()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateSeedConsistency(voteRepo, allUsers, posts); err != nil {
|
if err := validateSeedConsistency(voteRepo, allUsers, posts); err != nil {
|
||||||
@@ -231,11 +171,15 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) {
|
func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) {
|
||||||
if user, err := userRepo.GetByUsername(seedUsername); err == nil {
|
user, err := userRepo.GetByUsername(seedUsername)
|
||||||
|
if err == nil {
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, fmt.Errorf("failed to check if seed user exists: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
user := &database.User{
|
user = &database.User{
|
||||||
Username: seedUsername,
|
Username: seedUsername,
|
||||||
Email: seedEmail,
|
Email: seedEmail,
|
||||||
Password: passwordHash,
|
Password: passwordHash,
|
||||||
@@ -249,10 +193,6 @@ func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
|
|
||||||
return voteRepo.GetVoteCountsByPostID(postID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateSeedConsistency(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
|
func validateSeedConsistency(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
|
||||||
userIDSet := make(map[uint]struct{}, len(users))
|
userIDSet := make(map[uint]struct{}, len(users))
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
@@ -265,8 +205,11 @@ func validateSeedConsistency(voteRepo repositories.VoteRepository, users []datab
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, post := range posts {
|
for _, post := range posts {
|
||||||
if err := validatePost(post, userIDSet); err != nil {
|
if post.AuthorID == nil {
|
||||||
return err
|
return fmt.Errorf("post %d has no author ID", post.ID)
|
||||||
|
}
|
||||||
|
if _, exists := userIDSet[*post.AuthorID]; !exists {
|
||||||
|
return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID)
|
||||||
}
|
}
|
||||||
|
|
||||||
votes, err := voteRepo.GetByPostID(post.ID)
|
votes, err := voteRepo.GetByPostID(post.ID)
|
||||||
@@ -274,46 +217,293 @@ func validateSeedConsistency(voteRepo repositories.VoteRepository, users []datab
|
|||||||
return fmt.Errorf("failed to retrieve votes for post %d: %w", post.ID, err)
|
return fmt.Errorf("failed to retrieve votes for post %d: %w", post.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateVotesForPost(post.ID, votes, userIDSet, postIDSet); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func validatePost(post database.Post, userIDSet map[uint]struct{}) error {
|
|
||||||
if post.AuthorID == nil {
|
|
||||||
return fmt.Errorf("post %d has no author ID", post.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, exists := userIDSet[*post.AuthorID]; !exists {
|
|
||||||
return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateVotesForPost(postID uint, votes []database.Vote, userIDSet map[uint]struct{}, postIDSet map[uint]struct{}) error {
|
|
||||||
for _, vote := range votes {
|
for _, vote := range votes {
|
||||||
if vote.PostID != postID {
|
if vote.PostID != post.ID {
|
||||||
return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, postID)
|
return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, post.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, exists := postIDSet[vote.PostID]; !exists {
|
if _, exists := postIDSet[vote.PostID]; !exists {
|
||||||
return fmt.Errorf("vote %d references non-existent post ID %d", vote.ID, vote.PostID)
|
return fmt.Errorf("vote %d references non-existent post ID %d", vote.ID, vote.PostID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if vote.UserID != nil {
|
if vote.UserID != nil {
|
||||||
if _, exists := userIDSet[*vote.UserID]; !exists {
|
if _, exists := userIDSet[*vote.UserID]; !exists {
|
||||||
return fmt.Errorf("vote %d references non-existent user ID %d", vote.ID, *vote.UserID)
|
return fmt.Errorf("vote %d references non-existent user ID %d", vote.ID, *vote.UserID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if vote.Type != database.VoteUp && vote.Type != database.VoteDown {
|
if vote.Type != database.VoteUp && vote.Type != database.VoteDown {
|
||||||
return fmt.Errorf("vote %d has invalid type %q", vote.ID, vote.Type)
|
return fmt.Errorf("vote %d has invalid type %q", vote.ID, vote.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type seedGenerator struct {
|
||||||
|
passwordHash string
|
||||||
|
randSource *rand.Rand
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSeedGenerator(passwordHash string) *seedGenerator {
|
||||||
|
seed := time.Now().UnixNano()
|
||||||
|
return &seedGenerator{
|
||||||
|
passwordHash: passwordHash,
|
||||||
|
randSource: rand.New(rand.NewSource(seed)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRetryableError(err error, keywords ...string) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
errMsg := strings.ToLower(err.Error())
|
||||||
|
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||||
|
for _, keyword := range keywords {
|
||||||
|
if strings.Contains(errMsg, keyword) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var pqErr *pq.Error
|
||||||
|
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
|
||||||
|
constraintLower := strings.ToLower(pqErr.Constraint)
|
||||||
|
errMsgLower := strings.ToLower(pqErr.Message)
|
||||||
|
for _, keyword := range keywords {
|
||||||
|
if strings.Contains(constraintLower, keyword) || strings.Contains(errMsgLower, keyword) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(errMsg, "duplicate") {
|
||||||
|
for _, keyword := range keywords {
|
||||||
|
if strings.Contains(errMsg, keyword) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUsers(g *seedGenerator, userRepo repositories.UserRepository, count int, desc string) ([]database.User, error) {
|
||||||
|
if count == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
progress := maybeProgress(count, desc)
|
||||||
|
users := make([]database.User, 0, count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
user, err := g.createSingleUser(userRepo, i+1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create random user: %w", err)
|
||||||
|
}
|
||||||
|
users = append(users, user)
|
||||||
|
if progress != nil {
|
||||||
|
progress.Increment()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if progress != nil {
|
||||||
|
progress.Complete()
|
||||||
|
}
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createPosts(g *seedGenerator, postRepo repositories.PostRepository, authorID uint, count int, desc string) ([]database.Post, error) {
|
||||||
|
if count == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
progress := maybeProgress(count, desc)
|
||||||
|
posts := make([]database.Post, 0, count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
post, err := g.createSinglePost(postRepo, authorID, i+1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create random post: %w", err)
|
||||||
|
}
|
||||||
|
posts = append(posts, post)
|
||||||
|
if progress != nil {
|
||||||
|
progress.Increment()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if progress != nil {
|
||||||
|
progress.Complete()
|
||||||
|
}
|
||||||
|
return posts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createVotes(g *seedGenerator, voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, desc string) (int, error) {
|
||||||
|
if len(posts) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
progress := maybeProgress(len(posts), desc)
|
||||||
|
votes := 0
|
||||||
|
for _, post := range posts {
|
||||||
|
count, err := g.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("create random votes for post %d: %w", post.ID, err)
|
||||||
|
}
|
||||||
|
votes += count
|
||||||
|
if progress != nil {
|
||||||
|
progress.Increment()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if progress != nil {
|
||||||
|
progress.Complete()
|
||||||
|
}
|
||||||
|
return votes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateScores(g *seedGenerator, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, desc string) error {
|
||||||
|
if len(posts) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
progress := maybeProgress(len(posts), desc)
|
||||||
|
for _, post := range posts {
|
||||||
|
if err := g.updateSinglePostScore(postRepo, voteRepo, post); err != nil {
|
||||||
|
return fmt.Errorf("update post scores: %w", err)
|
||||||
|
}
|
||||||
|
if progress != nil {
|
||||||
|
progress.Increment()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if progress != nil {
|
||||||
|
progress.Complete()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func maybeProgress(count int, desc string) *ProgressIndicator {
|
||||||
|
if !IsJSONOutput() && count > 0 {
|
||||||
|
return NewProgressIndicator(count, desc)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *seedGenerator) generateRandomIdentifier() string {
|
||||||
|
const length = 12
|
||||||
|
const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
|
identifier := make([]byte, length)
|
||||||
|
for i := range identifier {
|
||||||
|
identifier[i] = chars[g.randSource.Intn(len(chars))]
|
||||||
|
}
|
||||||
|
return string(identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *seedGenerator) createSingleUser(userRepo repositories.UserRepository, index int) (database.User, error) {
|
||||||
|
const maxRetries = 10
|
||||||
|
var lastErr error
|
||||||
|
for attempt := range maxRetries {
|
||||||
|
randomID := g.generateRandomIdentifier()
|
||||||
|
user := &database.User{
|
||||||
|
Username: fmt.Sprintf("user_%s", randomID),
|
||||||
|
Email: fmt.Sprintf("user_%s@goyco.local", randomID),
|
||||||
|
Password: g.passwordHash,
|
||||||
|
EmailVerified: true,
|
||||||
|
}
|
||||||
|
if err := userRepo.Create(user); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
if !isRetryableError(err, "username", "email", "users_username_key", "users_email_key") {
|
||||||
|
return database.User{}, fmt.Errorf("failed to create user (attempt %d/%d): %w", attempt+1, maxRetries, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return *user, nil
|
||||||
|
}
|
||||||
|
return database.User{}, fmt.Errorf("failed to create user after %d attempts: %w", maxRetries, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
sampleTitles = []string{"Amazing JavaScript Framework", "Python Best Practices", "Go Performance Tips", "Database Optimization", "Web Security Guide", "Machine Learning Basics", "Cloud Architecture", "DevOps Automation", "API Design Patterns", "Frontend Optimization", "Backend Scaling", "Container Orchestration", "Microservices Architecture", "Testing Strategies", "Code Review Process", "Version Control Best Practices", "Continuous Integration", "Monitoring and Alerting", "Error Handling Patterns", "Data Structures Explained"}
|
||||||
|
sampleDomains = []string{"example.com", "techblog.org", "devguide.net", "programming.io", "codeexamples.com", "tutorialhub.org", "bestpractices.dev", "learnprogramming.net", "codingtips.org", "softwareengineering.com"}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (g *seedGenerator) createSinglePost(postRepo repositories.PostRepository, authorID uint, index int) (database.Post, error) {
|
||||||
|
title := sampleTitles[index%len(sampleTitles)]
|
||||||
|
if index >= len(sampleTitles) {
|
||||||
|
title = fmt.Sprintf("%s - Part %d", title, (index/len(sampleTitles))+1)
|
||||||
|
}
|
||||||
|
domain := sampleDomains[index%len(sampleDomains)]
|
||||||
|
content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", index, title)
|
||||||
|
|
||||||
|
const maxRetries = 10
|
||||||
|
var lastErr error
|
||||||
|
for attempt := range maxRetries {
|
||||||
|
randomID := g.generateRandomIdentifier()
|
||||||
|
post := &database.Post{
|
||||||
|
Title: title,
|
||||||
|
URL: fmt.Sprintf("https://%s/article/%s", domain, randomID),
|
||||||
|
Content: content,
|
||||||
|
AuthorID: &authorID,
|
||||||
|
UpVotes: 0,
|
||||||
|
DownVotes: 0,
|
||||||
|
Score: 0,
|
||||||
|
}
|
||||||
|
if err := postRepo.Create(post); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
if !isRetryableError(err, "url", "posts_url_key") {
|
||||||
|
return database.Post{}, fmt.Errorf("failed to create post (attempt %d/%d): %w", attempt+1, maxRetries, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return *post, nil
|
||||||
|
}
|
||||||
|
return database.Post{}, fmt.Errorf("failed to create post after %d attempts: %w", maxRetries, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *seedGenerator) createVotesForPost(voteRepo repositories.VoteRepository, users []database.User, post database.Post, avgVotesPerPost int) (int, error) {
|
||||||
|
numVotes := g.randSource.Intn(avgVotesPerPost*2 + 1)
|
||||||
|
|
||||||
|
if numVotes == 0 && avgVotesPerPost > 0 {
|
||||||
|
if g.randSource.Intn(5) > 0 {
|
||||||
|
numVotes = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
totalVotes := 0
|
||||||
|
usedUsers := make(map[uint]bool)
|
||||||
|
|
||||||
|
for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
|
||||||
|
userIdx := g.randSource.Intn(len(users))
|
||||||
|
user := users[userIdx]
|
||||||
|
|
||||||
|
if usedUsers[user.ID] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
usedUsers[user.ID] = true
|
||||||
|
|
||||||
|
voteTypeInt := g.randSource.Intn(10)
|
||||||
|
var voteType database.VoteType
|
||||||
|
if voteTypeInt < 7 {
|
||||||
|
voteType = database.VoteUp
|
||||||
|
} else {
|
||||||
|
voteType = database.VoteDown
|
||||||
|
}
|
||||||
|
|
||||||
|
vote := &database.Vote{
|
||||||
|
UserID: &user.ID,
|
||||||
|
PostID: post.ID,
|
||||||
|
Type: voteType,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := voteRepo.CreateOrUpdate(vote); err != nil {
|
||||||
|
return totalVotes, fmt.Errorf("create or update vote: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
totalVotes++
|
||||||
|
}
|
||||||
|
|
||||||
|
return totalVotes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *seedGenerator) updateSinglePostScore(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, post database.Post) error {
|
||||||
|
upVotes, downVotes, err := voteRepo.GetVoteCountsByPostID(post.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get vote counts: %w", err)
|
||||||
|
}
|
||||||
|
post.UpVotes = upVotes
|
||||||
|
post.DownVotes = downVotes
|
||||||
|
post.Score = upVotes - downVotes
|
||||||
|
return postRepo.Update(&post)
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,8 +2,11 @@ package commands
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"goyco/internal/database"
|
"goyco/internal/database"
|
||||||
"goyco/internal/repositories"
|
"goyco/internal/repositories"
|
||||||
@@ -13,6 +16,20 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
seedRandSource *rand.Rand
|
||||||
|
seedRandOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
const testPasswordHash = "test_password_hash"
|
||||||
|
|
||||||
|
func initSeedRand() {
|
||||||
|
seedRandOnce.Do(func() {
|
||||||
|
seed := time.Now().UnixNano()
|
||||||
|
seedRandSource = rand.New(rand.NewSource(seed))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestSeedCommand(t *testing.T) {
|
func TestSeedCommand(t *testing.T) {
|
||||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -46,11 +63,11 @@ func TestSeedCommand(t *testing.T) {
|
|||||||
seedUserCount := 0
|
seedUserCount := 0
|
||||||
var seedUser *database.User
|
var seedUser *database.User
|
||||||
regularUserCount := 0
|
regularUserCount := 0
|
||||||
for i := range users {
|
for idx := range users {
|
||||||
if users[i].Username == "seed_admin" {
|
if users[idx].Username == seedUsername {
|
||||||
seedUserCount++
|
seedUserCount++
|
||||||
seedUser = &users[i]
|
seedUser = &users[idx]
|
||||||
} else if strings.HasPrefix(users[i].Username, "user_") {
|
} else if strings.HasPrefix(users[idx].Username, "user_") {
|
||||||
regularUserCount++
|
regularUserCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -63,12 +80,12 @@ func TestSeedCommand(t *testing.T) {
|
|||||||
t.Fatal("Expected seed user to be created")
|
t.Fatal("Expected seed user to be created")
|
||||||
}
|
}
|
||||||
|
|
||||||
if seedUser.Username != "seed_admin" {
|
if seedUser.Username != seedUsername {
|
||||||
t.Errorf("Expected username to be 'seed_admin', got '%s'", seedUser.Username)
|
t.Errorf("Expected username to be %q, got '%s'", seedUsername, seedUser.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
if seedUser.Email != "seed_admin@goyco.local" {
|
if seedUser.Email != seedEmail {
|
||||||
t.Errorf("Expected email to be 'seed_admin@goyco.local', got '%s'", seedUser.Email)
|
t.Errorf("Expected email to be %q, got '%s'", seedEmail, seedUser.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !seedUser.EmailVerified {
|
if !seedUser.EmailVerified {
|
||||||
@@ -88,20 +105,20 @@ func TestSeedCommand(t *testing.T) {
|
|||||||
t.Errorf("Expected 5 posts, got %d", len(posts))
|
t.Errorf("Expected 5 posts, got %d", len(posts))
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, post := range posts {
|
for idx, post := range posts {
|
||||||
if post.Title == "" {
|
if post.Title == "" {
|
||||||
t.Errorf("Post %d has empty title", i)
|
t.Errorf("Post %d has empty title", idx)
|
||||||
}
|
}
|
||||||
if post.URL == "" {
|
if post.URL == "" {
|
||||||
t.Errorf("Post %d has empty URL", i)
|
t.Errorf("Post %d has empty URL", idx)
|
||||||
}
|
}
|
||||||
if post.AuthorID == nil || *post.AuthorID != seedUser.ID {
|
if post.AuthorID == nil || *post.AuthorID != seedUser.ID {
|
||||||
t.Errorf("Post %d has wrong author ID: expected %d, got %v", i, seedUser.ID, post.AuthorID)
|
t.Errorf("Post %d has wrong author ID: expected %d, got %v", idx, seedUser.ID, post.AuthorID)
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedScore := post.UpVotes - post.DownVotes
|
expectedScore := post.UpVotes - post.DownVotes
|
||||||
if post.Score != expectedScore {
|
if post.Score != expectedScore {
|
||||||
t.Errorf("Post %d has incorrect score: expected %d, got %d", i, expectedScore, post.Score)
|
t.Errorf("Post %d has incorrect score: expected %d, got %d", idx, expectedScore, post.Score)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,11 +150,12 @@ func TestSeedCommand(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGenerateRandomPath(t *testing.T) {
|
func TestGenerateRandomPath(t *testing.T) {
|
||||||
|
const articlePathPrefix = "/article/"
|
||||||
initSeedRand()
|
initSeedRand()
|
||||||
pathLength := seedRandSource.Intn(20)
|
pathLength := seedRandSource.Intn(20)
|
||||||
path := "/article/"
|
path := articlePathPrefix
|
||||||
|
|
||||||
for i := 0; i < pathLength+5; i++ {
|
for idx := 0; idx < pathLength+5; idx++ {
|
||||||
randomChar := seedRandSource.Intn(26)
|
randomChar := seedRandSource.Intn(26)
|
||||||
path += string(rune('a' + randomChar))
|
path += string(rune('a' + randomChar))
|
||||||
}
|
}
|
||||||
@@ -152,13 +170,14 @@ func TestGenerateRandomPath(t *testing.T) {
|
|||||||
|
|
||||||
initSeedRand()
|
initSeedRand()
|
||||||
secondPathLength := seedRandSource.Intn(20)
|
secondPathLength := seedRandSource.Intn(20)
|
||||||
secondPath := "/article/"
|
var secondPath strings.Builder
|
||||||
for i := 0; i < secondPathLength+5; i++ {
|
secondPath.WriteString(articlePathPrefix)
|
||||||
|
for idx := 0; idx < secondPathLength+5; idx++ {
|
||||||
randomChar := seedRandSource.Intn(26)
|
randomChar := seedRandSource.Intn(26)
|
||||||
secondPath += string(rune('a' + randomChar))
|
secondPath.WriteString(string(rune('a' + randomChar)))
|
||||||
}
|
}
|
||||||
|
|
||||||
if path == secondPath {
|
if path == secondPath.String() {
|
||||||
t.Error("Generated paths should be different")
|
t.Error("Generated paths should be different")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -271,6 +290,22 @@ func TestSeedDatabaseFlagParsing(t *testing.T) {
|
|||||||
t.Errorf("zero votes-per-post should be valid, got error: %v", err)
|
t.Errorf("zero votes-per-post should be valid, got error: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("help flag returns no error", func(t *testing.T) {
|
||||||
|
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--help"})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("help flag should return no error, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("short help flag returns no error", func(t *testing.T) {
|
||||||
|
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"-h"})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("short help flag should return no error, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSeedCommandIdempotency(t *testing.T) {
|
func TestSeedCommandIdempotency(t *testing.T) {
|
||||||
@@ -302,7 +337,7 @@ func TestSeedCommandIdempotency(t *testing.T) {
|
|||||||
|
|
||||||
seedUserCount := 0
|
seedUserCount := 0
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
if user.Username == "seed_admin" {
|
if user.Username == seedUsername {
|
||||||
seedUserCount++
|
seedUserCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -338,10 +373,10 @@ func TestSeedCommandIdempotency(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("database remains consistent after multiple runs", func(t *testing.T) {
|
t.Run("database remains consistent after multiple runs", func(t *testing.T) {
|
||||||
for i := range 2 {
|
for idx := range 2 {
|
||||||
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "0", "--posts", "1"})
|
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "0", "--posts", "1"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Seed run %d failed: %v", i+1, err)
|
t.Fatalf("Seed run %d failed: %v", idx+1, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -386,9 +421,9 @@ func TestSeedCommandIdempotency(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func findSeedUser(users []database.User) *database.User {
|
func findSeedUser(users []database.User) *database.User {
|
||||||
for i := range users {
|
for idx := range users {
|
||||||
if users[i].Username == "seed_admin" {
|
if users[idx].Username == seedUsername {
|
||||||
return &users[i]
|
return &users[idx]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -488,14 +523,14 @@ func TestEnsureSeedUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
userRepo := repositories.NewUserRepository(db)
|
userRepo := repositories.NewUserRepository(db)
|
||||||
passwordHash := "test_password_hash"
|
passwordHash := testPasswordHash
|
||||||
|
|
||||||
firstUser, err := ensureSeedUser(userRepo, passwordHash)
|
firstUser, err := ensureSeedUser(userRepo, passwordHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create seed user: %v", err)
|
t.Fatalf("Failed to create seed user: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if firstUser.Username != "seed_admin" || firstUser.Email != "seed_admin@goyco.local" || firstUser.Password != passwordHash || !firstUser.EmailVerified {
|
if firstUser.Username != seedUsername || firstUser.Email != seedEmail || firstUser.Password != passwordHash || !firstUser.EmailVerified {
|
||||||
t.Errorf("Invalid seed user: username=%s, email=%s, password matches=%v, emailVerified=%v",
|
t.Errorf("Invalid seed user: username=%s, email=%s, password matches=%v, emailVerified=%v",
|
||||||
firstUser.Username, firstUser.Email, firstUser.Password == passwordHash, firstUser.EmailVerified)
|
firstUser.Username, firstUser.Email, firstUser.Password == passwordHash, firstUser.EmailVerified)
|
||||||
}
|
}
|
||||||
@@ -509,9 +544,9 @@ func TestEnsureSeedUser(t *testing.T) {
|
|||||||
t.Errorf("Expected same user to be reused (ID %d), got different user (ID %d)", firstUser.ID, secondUser.ID)
|
t.Errorf("Expected same user to be reused (ID %d), got different user (ID %d)", firstUser.ID, secondUser.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
for idx := range 3 {
|
||||||
if _, err := ensureSeedUser(userRepo, passwordHash); err != nil {
|
if _, err := ensureSeedUser(userRepo, passwordHash); err != nil {
|
||||||
t.Fatalf("Call %d failed: %v", i+1, err)
|
t.Fatalf("Call %d failed: %v", idx+1, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -522,7 +557,7 @@ func TestEnsureSeedUser(t *testing.T) {
|
|||||||
|
|
||||||
seedUserCount := 0
|
seedUserCount := 0
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
if user.Username == "seed_admin" {
|
if user.Username == seedUsername {
|
||||||
seedUserCount++
|
seedUserCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -531,3 +566,25 @@ func TestEnsureSeedUser(t *testing.T) {
|
|||||||
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
|
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnsureSeedUser_HandlesDatabaseErrors(t *testing.T) {
|
||||||
|
userRepo := testutils.NewMockUserRepository()
|
||||||
|
passwordHash := testPasswordHash
|
||||||
|
|
||||||
|
dbError := fmt.Errorf("database connection failed")
|
||||||
|
|
||||||
|
userRepo.SetGetByUsernameError(dbError)
|
||||||
|
|
||||||
|
_, err := ensureSeedUser(userRepo, passwordHash)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error when GetByUsername returns database error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "failed to check if seed user exists") {
|
||||||
|
t.Errorf("Expected error message about checking seed user, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), dbError.Error()) {
|
||||||
|
t.Errorf("Expected error to wrap original database error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
@@ -12,6 +11,7 @@ import (
|
|||||||
"goyco/internal/config"
|
"goyco/internal/config"
|
||||||
"goyco/internal/testutils"
|
"goyco/internal/testutils"
|
||||||
|
|
||||||
|
"github.com/urfave/cli/v3"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -36,32 +36,15 @@ func FuzzCLIArgs(f *testing.F) {
|
|||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
cmd := buildRootCommand(testutils.NewTestConfig())
|
||||||
|
for _, sub := range cmd.Commands {
|
||||||
|
sub.Action = func(context.Context, *cli.Command) error { return nil }
|
||||||
|
}
|
||||||
|
|
||||||
fs := flag.NewFlagSet("goyco", flag.ContinueOnError)
|
err := cmd.Run(context.Background(), append([]string{"goyco"}, args...))
|
||||||
fs.SetOutput(os.Stderr)
|
|
||||||
fs.Usage = printRootUsage
|
|
||||||
showHelp := fs.Bool("help", false, "show this help message")
|
|
||||||
|
|
||||||
err := fs.Parse(args)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !strings.Contains(err.Error(), "flag") && !strings.Contains(err.Error(), "help") {
|
if !strings.Contains(err.Error(), "flag") && !strings.Contains(err.Error(), "unknown command") {
|
||||||
t.Logf("Unexpected error format from flag parsing: %v", err)
|
t.Logf("Unexpected error format from command parsing: %v", err)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if *showHelp && err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
remaining := fs.Args()
|
|
||||||
if len(remaining) > 0 {
|
|
||||||
cmdName := remaining[0]
|
|
||||||
if len(cmdName) == 0 {
|
|
||||||
t.Fatal("Command name cannot be empty")
|
|
||||||
}
|
|
||||||
if !isValidUTF8(cmdName) {
|
|
||||||
t.Fatal("Command name must be valid UTF-8")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -96,12 +79,6 @@ func FuzzCommandDispatch(f *testing.F) {
|
|||||||
})
|
})
|
||||||
defer commands.SetDBConnector(nil)
|
defer commands.SetDBConnector(nil)
|
||||||
|
|
||||||
daemonCommands := map[string]bool{
|
|
||||||
"start": true,
|
|
||||||
"stop": true,
|
|
||||||
"status": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
f.Add("run")
|
f.Add("run")
|
||||||
f.Add("help")
|
f.Add("help")
|
||||||
f.Add("user")
|
f.Add("user")
|
||||||
@@ -121,18 +98,20 @@ func FuzzCommandDispatch(f *testing.F) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd := buildRootCommand(cfg)
|
||||||
|
for _, sub := range cmd.Commands {
|
||||||
|
sub.Action = func(context.Context, *cli.Command) error { return nil }
|
||||||
|
}
|
||||||
|
|
||||||
cmdName := parts[0]
|
cmdName := parts[0]
|
||||||
args := parts[1:]
|
args := parts[1:]
|
||||||
|
|
||||||
if daemonCommands[cmdName] {
|
err := cmd.Run(context.Background(), append([]string{"goyco"}, append([]string{cmdName}, args...)...))
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := dispatchCommand(cfg, cmdName, args)
|
|
||||||
|
|
||||||
knownCommands := map[string]bool{
|
knownCommands := map[string]bool{
|
||||||
"run": true, "user": true, "post": true, "prune": true, "migrate": true,
|
"run": true, "user": true, "post": true, "prune": true, "migrate": true,
|
||||||
"migrations": true, "seed": true, "help": true, "-h": true, "--help": true,
|
"migrations": true, "seed": true, "help": true, "-h": true, "--help": true,
|
||||||
|
"start": true, "stop": true, "status": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if knownCommands[cmdName] {
|
if knownCommands[cmdName] {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
// @title Goyco API
|
// @title Goyco API
|
||||||
// @version 0.1.0
|
// @version 0.1.1
|
||||||
// @description Goyco is a Y Combinator-style news aggregation platform API.
|
// @description Goyco is a Y Combinator-style news aggregation platform API.
|
||||||
// @contact.name Goyco Team
|
// @contact.name Goyco Team
|
||||||
// @contact.email sandro@cazzaniga.fr
|
// @contact.email sandro@cazzaniga.fr
|
||||||
@@ -12,8 +12,8 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
@@ -55,7 +55,7 @@ func run(args []string) error {
|
|||||||
|
|
||||||
docs.SwaggerInfo.Title = fmt.Sprintf("%s API", cfg.App.Title)
|
docs.SwaggerInfo.Title = fmt.Sprintf("%s API", cfg.App.Title)
|
||||||
docs.SwaggerInfo.Description = "Y Combinator-style news board API."
|
docs.SwaggerInfo.Description = "Y Combinator-style news board API."
|
||||||
docs.SwaggerInfo.Version = version.Version
|
docs.SwaggerInfo.Version = version.GetVersion()
|
||||||
docs.SwaggerInfo.BasePath = "/api"
|
docs.SwaggerInfo.BasePath = "/api"
|
||||||
docs.SwaggerInfo.Host = fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port)
|
docs.SwaggerInfo.Host = fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port)
|
||||||
docs.SwaggerInfo.Schemes = []string{"http"}
|
docs.SwaggerInfo.Schemes = []string{"http"}
|
||||||
@@ -63,62 +63,9 @@ func run(args []string) error {
|
|||||||
docs.SwaggerInfo.Schemes = append(docs.SwaggerInfo.Schemes, "https")
|
docs.SwaggerInfo.Schemes = append(docs.SwaggerInfo.Schemes, "https")
|
||||||
}
|
}
|
||||||
|
|
||||||
rootFS := flag.NewFlagSet("goyco", flag.ContinueOnError)
|
root := buildRootCommand(cfg)
|
||||||
rootFS.SetOutput(os.Stderr)
|
runArgs := append([]string{os.Args[0]}, args...)
|
||||||
rootFS.Usage = printRootUsage
|
return root.Run(context.Background(), runArgs)
|
||||||
showHelp := rootFS.Bool("help", false, "show this help message")
|
|
||||||
jsonOutput := rootFS.Bool("json", cfg.CLI.JSONOutputDefault, "output results in JSON format")
|
|
||||||
|
|
||||||
if err := rootFS.Parse(args); err != nil {
|
|
||||||
if errors.Is(err, flag.ErrHelp) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to parse arguments: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if *showHelp {
|
|
||||||
printRootUsage()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
commands.SetJSONOutput(*jsonOutput)
|
|
||||||
|
|
||||||
remaining := rootFS.Args()
|
|
||||||
if len(remaining) == 0 {
|
|
||||||
printRootUsage()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return dispatchCommand(cfg, remaining[0], remaining[1:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func dispatchCommand(cfg *config.Config, name string, args []string) error {
|
|
||||||
switch name {
|
|
||||||
case "run":
|
|
||||||
return handleRunCommand(cfg, args)
|
|
||||||
case "start":
|
|
||||||
return commands.HandleStartCommand(cfg, args)
|
|
||||||
case "stop":
|
|
||||||
return commands.HandleStopCommand(cfg, args)
|
|
||||||
case "status":
|
|
||||||
return commands.HandleStatusCommand(cfg, name, args)
|
|
||||||
case "user":
|
|
||||||
return commands.HandleUserCommand(cfg, name, args)
|
|
||||||
case "post":
|
|
||||||
return commands.HandlePostCommand(cfg, name, args)
|
|
||||||
case "prune":
|
|
||||||
return commands.HandlePruneCommand(cfg, name, args)
|
|
||||||
case "migrate", "migrations":
|
|
||||||
return commands.HandleMigrateCommand(cfg, name, args)
|
|
||||||
case "seed":
|
|
||||||
return commands.HandleSeedCommand(cfg, name, args)
|
|
||||||
case "help", "-h", "--help":
|
|
||||||
printRootUsage()
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
printRootUsage()
|
|
||||||
return fmt.Errorf("unknown command: %s", name)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleRunCommand(cfg *config.Config, args []string) error {
|
func handleRunCommand(cfg *config.Config, args []string) error {
|
||||||
|
|||||||
@@ -3,14 +3,13 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
|
||||||
"flag"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"goyco/cmd/goyco/commands"
|
||||||
"goyco/internal/config"
|
"goyco/internal/config"
|
||||||
"goyco/internal/database"
|
"goyco/internal/database"
|
||||||
"goyco/internal/handlers"
|
"goyco/internal/handlers"
|
||||||
@@ -76,6 +75,10 @@ func TestServerConfigurationFromConfig(t *testing.T) {
|
|||||||
IdleTimeout: cfg.Server.IdleTimeout,
|
IdleTimeout: cfg.Server.IdleTimeout,
|
||||||
MaxHeaderBytes: cfg.Server.MaxHeaderBytes,
|
MaxHeaderBytes: cfg.Server.MaxHeaderBytes,
|
||||||
}
|
}
|
||||||
|
expectedAddr := cfg.Server.Host + ":" + cfg.Server.Port
|
||||||
|
if srv.Addr != expectedAddr {
|
||||||
|
t.Errorf("Expected server address to be %q, got %q", expectedAddr, srv.Addr)
|
||||||
|
}
|
||||||
|
|
||||||
if srv.ReadTimeout != 30*time.Second {
|
if srv.ReadTimeout != 30*time.Second {
|
||||||
t.Errorf("Expected ReadTimeout to be 30s, got %v", srv.ReadTimeout)
|
t.Errorf("Expected ReadTimeout to be 30s, got %v", srv.ReadTimeout)
|
||||||
@@ -172,6 +175,10 @@ func TestTLSWiringFromConfig(t *testing.T) {
|
|||||||
t.Errorf("Expected server address to be %q, got %q", expectedAddr, srv.Addr)
|
t.Errorf("Expected server address to be %q, got %q", expectedAddr, srv.Addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if srv.ReadHeaderTimeout != 5*time.Second {
|
||||||
|
t.Errorf("Expected ReadHeaderTimeout to be 5s, got %v", srv.ReadHeaderTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.Server.EnableTLS {
|
if cfg.Server.EnableTLS {
|
||||||
srv.TLSConfig = &tls.Config{
|
srv.TLSConfig = &tls.Config{
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
@@ -267,36 +274,37 @@ func TestConfigLoadingInCLI(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFlagParsingInCLI(t *testing.T) {
|
func TestFlagParsingInCLI(t *testing.T) {
|
||||||
originalArgs := os.Args
|
|
||||||
defer func() {
|
|
||||||
os.Args = originalArgs
|
|
||||||
}()
|
|
||||||
|
|
||||||
t.Run("help flag", func(t *testing.T) {
|
t.Run("help flag", func(t *testing.T) {
|
||||||
os.Args = []string{"goyco", "--help"}
|
cmd := buildRootCommand(testutils.NewTestConfig())
|
||||||
fs := flag.NewFlagSet("goyco", flag.ContinueOnError)
|
err := cmd.Run(context.Background(), []string{"goyco", "--help"})
|
||||||
fs.SetOutput(os.Stderr)
|
if err != nil {
|
||||||
showHelp := fs.Bool("help", false, "show help")
|
|
||||||
|
|
||||||
err := fs.Parse([]string{"--help"})
|
|
||||||
if err != nil && !errors.Is(err, flag.ErrHelp) {
|
|
||||||
t.Errorf("Expected help flag parsing, got error: %v", err)
|
t.Errorf("Expected help flag parsing, got error: %v", err)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
if !*showHelp {
|
t.Run("json flag", func(t *testing.T) {
|
||||||
t.Error("Expected help flag to be true")
|
cmd := buildRootCommand(testutils.NewTestConfig())
|
||||||
|
err := cmd.Run(context.Background(), []string{"goyco", "--json"})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected json flag parsing, got error: %v", err)
|
||||||
}
|
}
|
||||||
|
if !commands.IsJSONOutput() {
|
||||||
|
t.Error("Expected json output to be enabled")
|
||||||
|
}
|
||||||
|
commands.SetJSONOutput(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("command dispatch", func(t *testing.T) {
|
t.Run("command dispatch", func(t *testing.T) {
|
||||||
cfg := testutils.NewTestConfig()
|
cfg := testutils.NewTestConfig()
|
||||||
|
|
||||||
err := dispatchCommand(cfg, "unknown", []string{})
|
cmd := buildRootCommand(cfg)
|
||||||
|
err := cmd.Run(context.Background(), []string{"goyco", "unknown"})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error for unknown command")
|
t.Error("Expected error for unknown command")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dispatchCommand(cfg, "help", []string{})
|
cmd = buildRootCommand(cfg)
|
||||||
|
err = cmd.Run(context.Background(), []string{"goyco", "help"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Help command should not error: %v", err)
|
t.Errorf("Help command should not error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -364,6 +372,26 @@ func TestServerInitializationFlow(t *testing.T) {
|
|||||||
if srv.Handler == nil {
|
if srv.Handler == nil {
|
||||||
t.Error("Expected server handler to be set")
|
t.Error("Expected server handler to be set")
|
||||||
}
|
}
|
||||||
|
expectedAddr := cfg.Server.Host + ":" + cfg.Server.Port
|
||||||
|
if srv.Addr != expectedAddr {
|
||||||
|
t.Errorf("Expected server address to be %q, got %q", expectedAddr, srv.Addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv.ReadTimeout != cfg.Server.ReadTimeout {
|
||||||
|
t.Errorf("Expected ReadTimeout to be %v, got %v", cfg.Server.ReadTimeout, srv.ReadTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv.WriteTimeout != cfg.Server.WriteTimeout {
|
||||||
|
t.Errorf("Expected WriteTimeout to be %v, got %v", cfg.Server.WriteTimeout, srv.WriteTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv.IdleTimeout != cfg.Server.IdleTimeout {
|
||||||
|
t.Errorf("Expected IdleTimeout to be %v, got %v", cfg.Server.IdleTimeout, srv.IdleTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv.MaxHeaderBytes != cfg.Server.MaxHeaderBytes {
|
||||||
|
t.Errorf("Expected MaxHeaderBytes to be %d, got %d", cfg.Server.MaxHeaderBytes, srv.MaxHeaderBytes)
|
||||||
|
}
|
||||||
|
|
||||||
testServer := httptest.NewServer(srv.Handler)
|
testServer := httptest.NewServer(srv.Handler)
|
||||||
defer testServer.Close()
|
defer testServer.Close()
|
||||||
|
|||||||
69
docs/docs.go
69
docs/docs.go
@@ -324,7 +324,7 @@ const docTemplate = `{
|
|||||||
"200": {
|
"200": {
|
||||||
"description": "Authentication successful",
|
"description": "Authentication successful",
|
||||||
"schema": {
|
"schema": {
|
||||||
"$ref": "#/definitions/handlers.AuthTokensResponse"
|
"$ref": "#/definitions/handlers.AuthResponse"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"400": {
|
"400": {
|
||||||
@@ -487,7 +487,7 @@ const docTemplate = `{
|
|||||||
},
|
},
|
||||||
"/api/auth/refresh": {
|
"/api/auth/refresh": {
|
||||||
"post": {
|
"post": {
|
||||||
"description": "Use a refresh token to get a new access token. This endpoint allows clients to obtain a new access token using a valid refresh token without requiring user credentials.",
|
"description": "Use a refresh token to get a new access token. The refresh token is rotated on success, and the previous refresh token becomes invalid.",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@@ -513,7 +513,7 @@ const docTemplate = `{
|
|||||||
"200": {
|
"200": {
|
||||||
"description": "Token refreshed successfully",
|
"description": "Token refreshed successfully",
|
||||||
"schema": {
|
"schema": {
|
||||||
"$ref": "#/definitions/handlers.AuthTokensResponse"
|
"$ref": "#/definitions/handlers.AuthResponse"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"400": {
|
"400": {
|
||||||
@@ -1906,7 +1906,7 @@ const docTemplate = `{
|
|||||||
"properties": {
|
"properties": {
|
||||||
"refresh_token": {
|
"refresh_token": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
"example": "f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -1971,7 +1971,7 @@ const docTemplate = `{
|
|||||||
"properties": {
|
"properties": {
|
||||||
"refresh_token": {
|
"refresh_token": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
"example": "f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -2064,63 +2064,6 @@ const docTemplate = `{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"handlers.AuthTokensDetail": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"access_token": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
|
||||||
},
|
|
||||||
"refresh_token": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "f94d4ddc7d9b4fcb9d3a2c44c400b780"
|
|
||||||
},
|
|
||||||
"user": {
|
|
||||||
"$ref": "#/definitions/handlers.AuthUserSummary"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"handlers.AuthTokensResponse": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"data": {
|
|
||||||
"$ref": "#/definitions/handlers.AuthTokensDetail"
|
|
||||||
},
|
|
||||||
"message": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "Authentication successful"
|
|
||||||
},
|
|
||||||
"success": {
|
|
||||||
"type": "boolean",
|
|
||||||
"example": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"handlers.AuthUserSummary": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"email": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "jane@example.com"
|
|
||||||
},
|
|
||||||
"email_verified": {
|
|
||||||
"type": "boolean",
|
|
||||||
"example": true
|
|
||||||
},
|
|
||||||
"id": {
|
|
||||||
"type": "integer",
|
|
||||||
"example": 42
|
|
||||||
},
|
|
||||||
"locked": {
|
|
||||||
"type": "boolean",
|
|
||||||
"example": false
|
|
||||||
},
|
|
||||||
"username": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "janedoe"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"handlers.CommonResponse": {
|
"handlers.CommonResponse": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -2186,7 +2129,7 @@ const docTemplate = `{
|
|||||||
|
|
||||||
// SwaggerInfo holds exported Swagger Info so clients can modify it
|
// SwaggerInfo holds exported Swagger Info so clients can modify it
|
||||||
var SwaggerInfo = &swag.Spec{
|
var SwaggerInfo = &swag.Spec{
|
||||||
Version: "0.1.0",
|
Version: "0.1.1",
|
||||||
Host: "localhost:8080",
|
Host: "localhost:8080",
|
||||||
BasePath: "/api",
|
BasePath: "/api",
|
||||||
Schemes: []string{"http"},
|
Schemes: []string{"http"},
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
"name": "GPLv3",
|
"name": "GPLv3",
|
||||||
"url": "https://www.gnu.org/licenses/gpl-3.0.html"
|
"url": "https://www.gnu.org/licenses/gpl-3.0.html"
|
||||||
},
|
},
|
||||||
"version": "0.1.0"
|
"version": "0.1.1"
|
||||||
},
|
},
|
||||||
"host": "localhost:8080",
|
"host": "localhost:8080",
|
||||||
"basePath": "/api",
|
"basePath": "/api",
|
||||||
@@ -321,7 +321,7 @@
|
|||||||
"200": {
|
"200": {
|
||||||
"description": "Authentication successful",
|
"description": "Authentication successful",
|
||||||
"schema": {
|
"schema": {
|
||||||
"$ref": "#/definitions/handlers.AuthTokensResponse"
|
"$ref": "#/definitions/handlers.AuthResponse"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"400": {
|
"400": {
|
||||||
@@ -484,7 +484,7 @@
|
|||||||
},
|
},
|
||||||
"/api/auth/refresh": {
|
"/api/auth/refresh": {
|
||||||
"post": {
|
"post": {
|
||||||
"description": "Use a refresh token to get a new access token. This endpoint allows clients to obtain a new access token using a valid refresh token without requiring user credentials.",
|
"description": "Use a refresh token to get a new access token. The refresh token is rotated on success, and the previous refresh token becomes invalid.",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@@ -510,7 +510,7 @@
|
|||||||
"200": {
|
"200": {
|
||||||
"description": "Token refreshed successfully",
|
"description": "Token refreshed successfully",
|
||||||
"schema": {
|
"schema": {
|
||||||
"$ref": "#/definitions/handlers.AuthTokensResponse"
|
"$ref": "#/definitions/handlers.AuthResponse"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"400": {
|
"400": {
|
||||||
@@ -1903,7 +1903,7 @@
|
|||||||
"properties": {
|
"properties": {
|
||||||
"refresh_token": {
|
"refresh_token": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
"example": "f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -1968,7 +1968,7 @@
|
|||||||
"properties": {
|
"properties": {
|
||||||
"refresh_token": {
|
"refresh_token": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
"example": "f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -2061,63 +2061,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"handlers.AuthTokensDetail": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"access_token": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
|
||||||
},
|
|
||||||
"refresh_token": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "f94d4ddc7d9b4fcb9d3a2c44c400b780"
|
|
||||||
},
|
|
||||||
"user": {
|
|
||||||
"$ref": "#/definitions/handlers.AuthUserSummary"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"handlers.AuthTokensResponse": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"data": {
|
|
||||||
"$ref": "#/definitions/handlers.AuthTokensDetail"
|
|
||||||
},
|
|
||||||
"message": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "Authentication successful"
|
|
||||||
},
|
|
||||||
"success": {
|
|
||||||
"type": "boolean",
|
|
||||||
"example": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"handlers.AuthUserSummary": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"email": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "jane@example.com"
|
|
||||||
},
|
|
||||||
"email_verified": {
|
|
||||||
"type": "boolean",
|
|
||||||
"example": true
|
|
||||||
},
|
|
||||||
"id": {
|
|
||||||
"type": "integer",
|
|
||||||
"example": 42
|
|
||||||
},
|
|
||||||
"locked": {
|
|
||||||
"type": "boolean",
|
|
||||||
"example": false
|
|
||||||
},
|
|
||||||
"username": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "janedoe"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"handlers.CommonResponse": {
|
"handlers.CommonResponse": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ definitions:
|
|||||||
dto.RefreshTokenRequest:
|
dto.RefreshTokenRequest:
|
||||||
properties:
|
properties:
|
||||||
refresh_token:
|
refresh_token:
|
||||||
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
|
example: f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7
|
||||||
type: string
|
type: string
|
||||||
required:
|
required:
|
||||||
- refresh_token
|
- refresh_token
|
||||||
@@ -105,7 +105,7 @@ definitions:
|
|||||||
dto.RevokeTokenRequest:
|
dto.RevokeTokenRequest:
|
||||||
properties:
|
properties:
|
||||||
refresh_token:
|
refresh_token:
|
||||||
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
|
example: f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7
|
||||||
type: string
|
type: string
|
||||||
required:
|
required:
|
||||||
- refresh_token
|
- refresh_token
|
||||||
@@ -171,46 +171,6 @@ definitions:
|
|||||||
success:
|
success:
|
||||||
type: boolean
|
type: boolean
|
||||||
type: object
|
type: object
|
||||||
handlers.AuthTokensDetail:
|
|
||||||
properties:
|
|
||||||
access_token:
|
|
||||||
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
|
|
||||||
type: string
|
|
||||||
refresh_token:
|
|
||||||
example: f94d4ddc7d9b4fcb9d3a2c44c400b780
|
|
||||||
type: string
|
|
||||||
user:
|
|
||||||
$ref: '#/definitions/handlers.AuthUserSummary'
|
|
||||||
type: object
|
|
||||||
handlers.AuthTokensResponse:
|
|
||||||
properties:
|
|
||||||
data:
|
|
||||||
$ref: '#/definitions/handlers.AuthTokensDetail'
|
|
||||||
message:
|
|
||||||
example: Authentication successful
|
|
||||||
type: string
|
|
||||||
success:
|
|
||||||
example: true
|
|
||||||
type: boolean
|
|
||||||
type: object
|
|
||||||
handlers.AuthUserSummary:
|
|
||||||
properties:
|
|
||||||
email:
|
|
||||||
example: jane@example.com
|
|
||||||
type: string
|
|
||||||
email_verified:
|
|
||||||
example: true
|
|
||||||
type: boolean
|
|
||||||
id:
|
|
||||||
example: 42
|
|
||||||
type: integer
|
|
||||||
locked:
|
|
||||||
example: false
|
|
||||||
type: boolean
|
|
||||||
username:
|
|
||||||
example: janedoe
|
|
||||||
type: string
|
|
||||||
type: object
|
|
||||||
handlers.CommonResponse:
|
handlers.CommonResponse:
|
||||||
properties:
|
properties:
|
||||||
data: {}
|
data: {}
|
||||||
@@ -261,7 +221,7 @@ info:
|
|||||||
name: GPLv3
|
name: GPLv3
|
||||||
url: https://www.gnu.org/licenses/gpl-3.0.html
|
url: https://www.gnu.org/licenses/gpl-3.0.html
|
||||||
title: Goyco API
|
title: Goyco API
|
||||||
version: 0.1.0
|
version: 0.1.1
|
||||||
paths:
|
paths:
|
||||||
/api:
|
/api:
|
||||||
get:
|
get:
|
||||||
@@ -459,7 +419,7 @@ paths:
|
|||||||
"200":
|
"200":
|
||||||
description: Authentication successful
|
description: Authentication successful
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/definitions/handlers.AuthTokensResponse'
|
$ref: '#/definitions/handlers.AuthResponse'
|
||||||
"400":
|
"400":
|
||||||
description: Invalid request data or validation failed
|
description: Invalid request data or validation failed
|
||||||
schema:
|
schema:
|
||||||
@@ -565,9 +525,8 @@ paths:
|
|||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: Use a refresh token to get a new access token. This endpoint allows
|
description: Use a refresh token to get a new access token. The refresh token
|
||||||
clients to obtain a new access token using a valid refresh token without requiring
|
is rotated on success, and the previous refresh token becomes invalid.
|
||||||
user credentials.
|
|
||||||
parameters:
|
parameters:
|
||||||
- description: Refresh token data
|
- description: Refresh token data
|
||||||
in: body
|
in: body
|
||||||
@@ -581,7 +540,7 @@ paths:
|
|||||||
"200":
|
"200":
|
||||||
description: Token refreshed successfully
|
description: Token refreshed successfully
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/definitions/handlers.AuthTokensResponse'
|
$ref: '#/definitions/handlers.AuthResponse'
|
||||||
"400":
|
"400":
|
||||||
description: Invalid request body or missing refresh token
|
description: Invalid request body or missing refresh token
|
||||||
schema:
|
schema:
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -12,6 +12,7 @@ require (
|
|||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/swaggo/http-swagger v1.3.4
|
github.com/swaggo/http-swagger v1.3.4
|
||||||
github.com/swaggo/swag v1.16.6
|
github.com/swaggo/swag v1.16.6
|
||||||
|
github.com/urfave/cli/v3 v3.6.1
|
||||||
golang.org/x/crypto v0.43.0
|
golang.org/x/crypto v0.43.0
|
||||||
golang.org/x/net v0.46.0
|
golang.org/x/net v0.46.0
|
||||||
gorm.io/driver/postgres v1.6.0
|
gorm.io/driver/postgres v1.6.0
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -116,3 +116,5 @@ gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
|||||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||||
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||||
|
github.com/urfave/cli/v3 v3.6.1 h1:j8Qq8NyUawj/7rTYdBGrxcH7A/j7/G8Q5LhWEW4G3Mo=
|
||||||
|
github.com/urfave/cli/v3 v3.6.1/go.mod h1:ysVLtOEmg2tOy6PknnYVhDoouyC/6N42TMeoMzskhso=
|
||||||
|
|||||||
@@ -231,11 +231,9 @@ func TestValidateJWTSecret(t *testing.T) {
|
|||||||
if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
||||||
t.Fatalf("expected error message to contain %q, got %q", tt.errorMsg, err.Error())
|
t.Fatalf("expected error message to contain %q, got %q", tt.errorMsg, err.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else if err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error for secret %q: %v", tt.secret, err)
|
t.Fatalf("unexpected error for secret %q: %v", tt.secret, err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -497,11 +495,9 @@ func TestValidateJWTConfig(t *testing.T) {
|
|||||||
if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
||||||
t.Fatalf("expected error message to contain %q, got %q", tt.errorMsg, err.Error())
|
t.Fatalf("expected error message to contain %q, got %q", tt.errorMsg, err.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else if err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error for config %+v: %v", tt.config, err)
|
t.Fatalf("unexpected error for config %+v: %v", tt.config, err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -626,11 +622,9 @@ func TestLoadWithInvalidJWTConfig(t *testing.T) {
|
|||||||
if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
||||||
t.Fatalf("expected error message to contain %q, got %q", tt.errorMsg, err.Error())
|
t.Fatalf("expected error message to contain %q, got %q", tt.errorMsg, err.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else if err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -912,11 +906,9 @@ func TestValidateBcryptCost(t *testing.T) {
|
|||||||
if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
||||||
t.Fatalf("expected error message to contain %q, got %q", tt.errorMsg, err.Error())
|
t.Fatalf("expected error message to contain %q, got %q", tt.errorMsg, err.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else if err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error for BCRYPT_COST %d: %v", tt.bcryptCost, err)
|
t.Fatalf("unexpected error for BCRYPT_COST %d: %v", tt.bcryptCost, err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -976,18 +968,15 @@ func TestLoadWithInvalidBcryptCost(t *testing.T) {
|
|||||||
if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
||||||
t.Fatalf("expected error message to contain %q, got %q", tt.errorMsg, err.Error())
|
t.Fatalf("expected error message to contain %q, got %q", tt.errorMsg, err.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else if err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
} else {
|
||||||
expectedCost := 12
|
expectedCost := 12
|
||||||
if tt.bcryptCost == "" {
|
if tt.bcryptCost == "" {
|
||||||
expectedCost = 10
|
expectedCost = 10
|
||||||
} else {
|
} else if costInt, err := strconv.Atoi(tt.bcryptCost); err == nil {
|
||||||
if costInt, err := strconv.Atoi(tt.bcryptCost); err == nil {
|
|
||||||
expectedCost = costInt
|
expectedCost = costInt
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if cfg.App.BcryptCost != expectedCost {
|
if cfg.App.BcryptCost != expectedCost {
|
||||||
t.Fatalf("expected BCRYPT_COST %d, got %d", expectedCost, cfg.App.BcryptCost)
|
t.Fatalf("expected BCRYPT_COST %d, got %d", expectedCost, cfg.App.BcryptCost)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,9 +43,9 @@ type ConfirmAccountDeletionRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RefreshTokenRequest struct {
|
type RefreshTokenRequest struct {
|
||||||
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." validate:"required"`
|
RefreshToken string `json:"refresh_token" example:"f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7" validate:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RevokeTokenRequest struct {
|
type RevokeTokenRequest struct {
|
||||||
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." validate:"required"`
|
RefreshToken string `json:"refresh_token" example:"f94d4ddc7d9b4fcb9d3a2c44c400b780c3e1f1a5c2b7d4e6a0b1c2d3e4f5a6b7" validate:"required"`
|
||||||
}
|
}
|
||||||
|
|||||||
23
internal/dto/auth_response.go
Normal file
23
internal/dto/auth_response.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"goyco/internal/services"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthResponseDTO struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
User UserDTO `json:"user"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToAuthResponseDTO(result *services.AuthResult) AuthResponseDTO {
|
||||||
|
if result == nil {
|
||||||
|
return AuthResponseDTO{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return AuthResponseDTO{
|
||||||
|
AccessToken: result.AccessToken,
|
||||||
|
RefreshToken: result.RefreshToken,
|
||||||
|
User: ToUserDTO(result.User),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -29,6 +29,14 @@ type PostListDTO struct {
|
|||||||
Offset int `json:"offset"`
|
Offset int `json:"offset"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SearchPostListDTO struct {
|
||||||
|
Posts []PostDTO `json:"posts"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
Query string `json:"query"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
Offset int `json:"offset"`
|
||||||
|
}
|
||||||
|
|
||||||
func ToPostDTO(post *database.Post) PostDTO {
|
func ToPostDTO(post *database.Post) PostDTO {
|
||||||
if post == nil {
|
if post == nil {
|
||||||
return PostDTO{}
|
return PostDTO{}
|
||||||
@@ -67,3 +75,28 @@ func ToPostDTOs(posts []database.Post) []PostDTO {
|
|||||||
}
|
}
|
||||||
return dtos
|
return dtos
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ToPostListDTO(posts []database.Post, limit, offset int) PostListDTO {
|
||||||
|
postDTOs := ToPostDTOs(posts)
|
||||||
|
return PostListDTO{
|
||||||
|
Posts: postDTOs,
|
||||||
|
Count: len(postDTOs),
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToSearchPostListDTO(posts []database.Post, query string, limit, offset int) SearchPostListDTO {
|
||||||
|
postDTOs := ToPostDTOs(posts)
|
||||||
|
return SearchPostListDTO{
|
||||||
|
Posts: postDTOs,
|
||||||
|
Count: len(postDTOs),
|
||||||
|
Query: query,
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TitleResponseDTO struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -74,3 +74,52 @@ func ToSanitizedUserDTOs(users []database.User) []SanitizedUserDTO {
|
|||||||
}
|
}
|
||||||
return dtos
|
return dtos
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SanitizedUserListDTO struct {
|
||||||
|
Users []SanitizedUserDTO `json:"users"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
Offset int `json:"offset"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToUserListDTO(users []database.User, limit, offset int) UserListDTO {
|
||||||
|
userDTOs := ToUserDTOs(users)
|
||||||
|
return UserListDTO{
|
||||||
|
Users: userDTOs,
|
||||||
|
Count: len(userDTOs),
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToSanitizedUserListDTO(users []database.User, limit, offset int) SanitizedUserListDTO {
|
||||||
|
userDTOs := ToSanitizedUserDTOs(users)
|
||||||
|
return SanitizedUserListDTO{
|
||||||
|
Users: userDTOs,
|
||||||
|
Count: len(userDTOs),
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type RegistrationResponseDTO struct {
|
||||||
|
User UserDTO `json:"user"`
|
||||||
|
VerificationSent bool `json:"verification_sent"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToRegistrationResponseDTO(user *database.User, verificationSent bool) RegistrationResponseDTO {
|
||||||
|
return RegistrationResponseDTO{
|
||||||
|
User: ToUserDTO(user),
|
||||||
|
VerificationSent: verificationSent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccountDeletionResponseDTO struct {
|
||||||
|
PostsDeleted bool `json:"posts_deleted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageResponseDTO struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmptyResponseDTO struct{}
|
||||||
|
|||||||
@@ -41,3 +41,14 @@ func ToVoteDTOs(votes []database.Vote) []VoteDTO {
|
|||||||
}
|
}
|
||||||
return dtos
|
return dtos
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type VoteResponseDTO struct {
|
||||||
|
HasVote bool `json:"has_vote"`
|
||||||
|
Vote *VoteDTO `json:"vote,omitempty"`
|
||||||
|
IsAnonymous bool `json:"is_anonymous"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type VoteListDTO struct {
|
||||||
|
Votes []VoteDTO `json:"votes"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,32 +8,43 @@ import (
|
|||||||
"goyco/internal/testutils"
|
"goyco/internal/testutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func fetchSwaggerDoc(t *testing.T, ctx *testContext) map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
request, err := http.NewRequest("GET", ctx.baseURL+"/swagger/doc.json", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
testutils.WithStandardHeaders(request)
|
||||||
|
|
||||||
|
response, err := ctx.client.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
if response.StatusCode != http.StatusOK {
|
||||||
|
t.Skipf("Swagger JSON not available (status %d)", response.StatusCode)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var swaggerDoc map[string]any
|
||||||
|
if err := json.NewDecoder(response.Body).Decode(&swaggerDoc); err != nil {
|
||||||
|
t.Fatalf("Failed to decode Swagger JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return swaggerDoc
|
||||||
|
}
|
||||||
|
|
||||||
func TestE2E_SwaggerDocumentation(t *testing.T) {
|
func TestE2E_SwaggerDocumentation(t *testing.T) {
|
||||||
ctx := setupTestContext(t)
|
ctx := setupTestContext(t)
|
||||||
|
|
||||||
t.Run("swagger_json_is_valid", func(t *testing.T) {
|
t.Run("swagger_json_is_valid", func(t *testing.T) {
|
||||||
req, err := http.NewRequest("GET", ctx.baseURL+"/swagger/doc.json", nil)
|
swaggerDoc := fetchSwaggerDoc(t, ctx)
|
||||||
if err != nil {
|
if swaggerDoc == nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
|
||||||
}
|
|
||||||
testutils.WithStandardHeaders(req)
|
|
||||||
|
|
||||||
resp, err := ctx.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
t.Skipf("Swagger JSON not available (status %d)", resp.StatusCode)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var swaggerDoc map[string]interface{}
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&swaggerDoc); err != nil {
|
|
||||||
t.Fatalf("Failed to decode Swagger JSON: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if swaggerDoc["swagger"] == nil && swaggerDoc["openapi"] == nil {
|
if swaggerDoc["swagger"] == nil && swaggerDoc["openapi"] == nil {
|
||||||
t.Error("Swagger JSON missing swagger/openapi version")
|
t.Error("Swagger JSON missing swagger/openapi version")
|
||||||
}
|
}
|
||||||
@@ -66,32 +77,18 @@ func TestE2E_SwaggerDocumentation(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("api_endpoints_documented", func(t *testing.T) {
|
t.Run("api_endpoints_documented", func(t *testing.T) {
|
||||||
req, err := http.NewRequest("GET", ctx.baseURL+"/swagger/doc.json", nil)
|
swaggerDoc := fetchSwaggerDoc(t, ctx)
|
||||||
if err != nil {
|
if swaggerDoc == nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
|
||||||
}
|
|
||||||
testutils.WithStandardHeaders(req)
|
|
||||||
|
|
||||||
resp, err := ctx.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
t.Skip("Swagger JSON not available")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var swaggerDoc map[string]interface{}
|
pathsRaw, exists := swaggerDoc["paths"]
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&swaggerDoc); err != nil {
|
if !exists {
|
||||||
t.Fatalf("Failed to decode Swagger JSON: %v", err)
|
t.Fatalf("Swagger doc missing 'paths' section")
|
||||||
}
|
}
|
||||||
|
paths, ok := pathsRaw.(map[string]any)
|
||||||
paths, ok := swaggerDoc["paths"].(map[string]interface{})
|
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error("Paths section is not a map")
|
t.Fatalf("Swagger doc 'paths' section has invalid type: expected map[string]any, got %T", pathsRaw)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
requiredPaths := []string{
|
requiredPaths := []string{
|
||||||
@@ -99,6 +96,10 @@ func TestE2E_SwaggerDocumentation(t *testing.T) {
|
|||||||
"/api/auth/login",
|
"/api/auth/login",
|
||||||
"/api/auth/register",
|
"/api/auth/register",
|
||||||
"/api/auth/me",
|
"/api/auth/me",
|
||||||
|
"/api/auth/refresh",
|
||||||
|
"/api/auth/revoke",
|
||||||
|
"/api/auth/revoke-all",
|
||||||
|
"/api/auth/logout",
|
||||||
"/api/posts",
|
"/api/posts",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,33 +111,34 @@ func TestE2E_SwaggerDocumentation(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("request_response_schemas_present", func(t *testing.T) {
|
t.Run("request_response_schemas_present", func(t *testing.T) {
|
||||||
req, err := http.NewRequest("GET", ctx.baseURL+"/swagger/doc.json", nil)
|
swaggerDoc := fetchSwaggerDoc(t, ctx)
|
||||||
if err != nil {
|
if swaggerDoc == nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
|
||||||
}
|
|
||||||
testutils.WithStandardHeaders(req)
|
|
||||||
|
|
||||||
resp, err := ctx.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
t.Skip("Swagger JSON not available")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var swaggerDoc map[string]interface{}
|
var definitions map[string]any
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&swaggerDoc); err != nil {
|
var ok bool
|
||||||
t.Fatalf("Failed to decode Swagger JSON: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
definitions, ok := swaggerDoc["definitions"].(map[string]interface{})
|
definitionsRaw, exists := swaggerDoc["definitions"]
|
||||||
|
if exists {
|
||||||
|
definitions, ok = definitionsRaw.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
definitions, ok = swaggerDoc["components"].(map[string]interface{})
|
t.Fatalf("Swagger doc 'definitions' section has invalid type: expected map[string]any, got %T", definitionsRaw)
|
||||||
if ok {
|
}
|
||||||
definitions, _ = definitions["schemas"].(map[string]interface{})
|
} else {
|
||||||
|
componentsRaw, exists := swaggerDoc["components"]
|
||||||
|
if exists {
|
||||||
|
components, ok := componentsRaw.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Swagger doc 'components' section has invalid type: expected map[string]any, got %T", componentsRaw)
|
||||||
|
}
|
||||||
|
schemasRaw, exists := components["schemas"]
|
||||||
|
if exists {
|
||||||
|
definitions, ok = schemasRaw.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Swagger doc 'components.schemas' section has invalid type: expected map[string]any, got %T", schemasRaw)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,40 +175,27 @@ func TestE2E_APIEndpointDocumentation(t *testing.T) {
|
|||||||
ctx := setupTestContext(t)
|
ctx := setupTestContext(t)
|
||||||
|
|
||||||
t.Run("api_info_endpoint_documented", func(t *testing.T) {
|
t.Run("api_info_endpoint_documented", func(t *testing.T) {
|
||||||
req, err := http.NewRequest("GET", ctx.baseURL+"/swagger/doc.json", nil)
|
swaggerDoc := fetchSwaggerDoc(t, ctx)
|
||||||
if err != nil {
|
if swaggerDoc == nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
|
||||||
}
|
|
||||||
testutils.WithStandardHeaders(req)
|
|
||||||
|
|
||||||
resp, err := ctx.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
t.Skip("Swagger JSON not available")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var swaggerDoc map[string]interface{}
|
pathsRaw, exists := swaggerDoc["paths"]
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&swaggerDoc); err != nil {
|
if !exists {
|
||||||
t.Fatalf("Failed to decode Swagger JSON: %v", err)
|
t.Fatalf("Swagger doc missing 'paths' section")
|
||||||
}
|
}
|
||||||
|
paths, ok := pathsRaw.(map[string]any)
|
||||||
paths, ok := swaggerDoc["paths"].(map[string]interface{})
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
t.Fatalf("Swagger doc 'paths' section has invalid type: expected map[string]any, got %T", pathsRaw)
|
||||||
}
|
}
|
||||||
|
|
||||||
apiPath, ok := paths["/api"].(map[string]interface{})
|
apiPath, ok := paths["/api"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error("API endpoint not documented")
|
t.Error("API endpoint not documented")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
getMethod, ok := apiPath["get"].(map[string]interface{})
|
getMethod, ok := apiPath["get"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error("API GET method not documented")
|
t.Error("API GET method not documented")
|
||||||
return
|
return
|
||||||
@@ -218,46 +207,37 @@ func TestE2E_APIEndpointDocumentation(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("auth_endpoints_documented", func(t *testing.T) {
|
t.Run("auth_endpoints_documented", func(t *testing.T) {
|
||||||
req, err := http.NewRequest("GET", ctx.baseURL+"/swagger/doc.json", nil)
|
swaggerDoc := fetchSwaggerDoc(t, ctx)
|
||||||
if err != nil {
|
if swaggerDoc == nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
|
||||||
}
|
|
||||||
testutils.WithStandardHeaders(req)
|
|
||||||
|
|
||||||
resp, err := ctx.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
t.Skip("Swagger JSON not available")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var swaggerDoc map[string]interface{}
|
pathsRaw, exists := swaggerDoc["paths"]
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&swaggerDoc); err != nil {
|
if !exists {
|
||||||
t.Fatalf("Failed to decode Swagger JSON: %v", err)
|
t.Fatalf("Swagger doc missing 'paths' section")
|
||||||
}
|
}
|
||||||
|
paths, ok := pathsRaw.(map[string]any)
|
||||||
paths, ok := swaggerDoc["paths"].(map[string]interface{})
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
t.Fatalf("Swagger doc 'paths' section has invalid type: expected map[string]any, got %T", pathsRaw)
|
||||||
}
|
}
|
||||||
|
|
||||||
authEndpoints := []string{
|
authEndpoints := []string{
|
||||||
"/api/auth/login",
|
"/api/auth/login",
|
||||||
"/api/auth/register",
|
"/api/auth/register",
|
||||||
|
"/api/auth/refresh",
|
||||||
|
"/api/auth/revoke",
|
||||||
|
"/api/auth/revoke-all",
|
||||||
|
"/api/auth/logout",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, endpoint := range authEndpoints {
|
for _, endpoint := range authEndpoints {
|
||||||
endpointData, ok := paths[endpoint].(map[string]interface{})
|
endpointData, ok := paths[endpoint].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("Auth endpoint %s not documented", endpoint)
|
t.Errorf("Auth endpoint %s not documented", endpoint)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
postMethod, ok := endpointData["post"].(map[string]interface{})
|
postMethod, ok := endpointData["post"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("Auth endpoint %s missing POST method", endpoint)
|
t.Errorf("Auth endpoint %s missing POST method", endpoint)
|
||||||
continue
|
continue
|
||||||
@@ -267,5 +247,39 @@ func TestE2E_APIEndpointDocumentation(t *testing.T) {
|
|||||||
t.Logf("Auth endpoint %s may use inline request body", endpoint)
|
t.Logf("Auth endpoint %s may use inline request body", endpoint)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
refreshEndpointData, ok := paths["/api/auth/refresh"].(map[string]any)
|
||||||
|
if ok {
|
||||||
|
postMethod, ok := refreshEndpointData["post"].(map[string]any)
|
||||||
|
if ok {
|
||||||
|
responses, ok := postMethod["responses"].(map[string]any)
|
||||||
|
if ok {
|
||||||
|
successResponse, ok := responses["200"].(map[string]any)
|
||||||
|
if ok {
|
||||||
|
content, ok := successResponse["content"].(map[string]any)
|
||||||
|
if ok {
|
||||||
|
applicationJson, ok := content["application/json"].(map[string]any)
|
||||||
|
if ok {
|
||||||
|
schema, ok := applicationJson["schema"].(map[string]any)
|
||||||
|
if ok {
|
||||||
|
properties, ok := schema["properties"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
data, ok := schema["data"].(map[string]any)
|
||||||
|
if ok {
|
||||||
|
properties, ok = data["properties"].(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if properties != nil {
|
||||||
|
if properties["refresh_token"] == nil {
|
||||||
|
t.Error("Refresh endpoint response schema missing refresh_token field (rotation not documented)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -271,8 +271,8 @@ func TestE2E_RefreshTokenFlow(t *testing.T) {
|
|||||||
t.Logf("New access token is identical to original (may occur if generated within same second)")
|
t.Logf("New access token is identical to original (may occur if generated within same second)")
|
||||||
}
|
}
|
||||||
|
|
||||||
if authClient.RefreshToken != originalRefreshToken {
|
if authClient.RefreshToken == originalRefreshToken {
|
||||||
t.Logf("Refresh token was changed (token rotation), which is acceptable")
|
t.Errorf("Expected refresh token to rotate")
|
||||||
}
|
}
|
||||||
|
|
||||||
profile := authClient.GetProfile(t)
|
profile := authClient.GetProfile(t)
|
||||||
|
|||||||
@@ -1,89 +1,35 @@
|
|||||||
package fuzz
|
package fuzz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"goyco/internal/database"
|
||||||
|
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
fuzzDBOnce sync.Once
|
|
||||||
fuzzDB *gorm.DB
|
|
||||||
fuzzDBErr error
|
|
||||||
)
|
|
||||||
|
|
||||||
func GetFuzzDB() (*gorm.DB, error) {
|
func GetFuzzDB() (*gorm.DB, error) {
|
||||||
fuzzDBOnce.Do(func() {
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
dbName := "file:memdb_fuzz?mode=memory&cache=shared&_journal_mode=WAL&_synchronous=NORMAL"
|
|
||||||
fuzzDB, fuzzDBErr = gorm.Open(sqlite.Open(dbName), &gorm.Config{
|
|
||||||
Logger: logger.Default.LogMode(logger.Silent),
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
})
|
})
|
||||||
if fuzzDBErr == nil {
|
if err != nil {
|
||||||
fuzzDBErr = fuzzDB.Exec(`
|
return nil, err
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
username TEXT UNIQUE NOT NULL,
|
|
||||||
email TEXT UNIQUE NOT NULL,
|
|
||||||
password TEXT NOT NULL,
|
|
||||||
email_verified INTEGER DEFAULT 0 NOT NULL,
|
|
||||||
email_verified_at DATETIME,
|
|
||||||
email_verification_token TEXT,
|
|
||||||
email_verification_sent_at DATETIME,
|
|
||||||
password_reset_token TEXT,
|
|
||||||
password_reset_sent_at DATETIME,
|
|
||||||
password_reset_expires_at DATETIME,
|
|
||||||
locked INTEGER DEFAULT 0,
|
|
||||||
session_version INTEGER DEFAULT 1 NOT NULL,
|
|
||||||
created_at DATETIME,
|
|
||||||
updated_at DATETIME,
|
|
||||||
deleted_at DATETIME
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS posts (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
title TEXT NOT NULL,
|
|
||||||
url TEXT UNIQUE,
|
|
||||||
content TEXT,
|
|
||||||
author_id INTEGER,
|
|
||||||
author_name TEXT,
|
|
||||||
up_votes INTEGER DEFAULT 0,
|
|
||||||
down_votes INTEGER DEFAULT 0,
|
|
||||||
score INTEGER DEFAULT 0,
|
|
||||||
created_at DATETIME,
|
|
||||||
updated_at DATETIME,
|
|
||||||
deleted_at DATETIME,
|
|
||||||
FOREIGN KEY(author_id) REFERENCES users(id)
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS votes (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
user_id INTEGER,
|
|
||||||
post_id INTEGER NOT NULL,
|
|
||||||
type TEXT NOT NULL,
|
|
||||||
vote_hash TEXT,
|
|
||||||
created_at DATETIME,
|
|
||||||
updated_at DATETIME,
|
|
||||||
FOREIGN KEY(user_id) REFERENCES users(id),
|
|
||||||
FOREIGN KEY(post_id) REFERENCES posts(id)
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS account_deletion_requests (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
user_id INTEGER NOT NULL,
|
|
||||||
token_hash TEXT UNIQUE NOT NULL,
|
|
||||||
expires_at DATETIME NOT NULL,
|
|
||||||
created_at DATETIME,
|
|
||||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
user_id INTEGER NOT NULL,
|
|
||||||
token_hash TEXT UNIQUE NOT NULL,
|
|
||||||
expires_at DATETIME NOT NULL,
|
|
||||||
created_at DATETIME,
|
|
||||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
|
||||||
);
|
|
||||||
`).Error
|
|
||||||
}
|
}
|
||||||
})
|
|
||||||
return fuzzDB, fuzzDBErr
|
if execErr := db.Exec("PRAGMA foreign_keys = ON").Error; execErr != nil {
|
||||||
|
return nil, execErr
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.AutoMigrate(
|
||||||
|
&database.User{},
|
||||||
|
&database.Post{},
|
||||||
|
&database.Vote{},
|
||||||
|
&database.AccountDeletionRequest{},
|
||||||
|
&database.RefreshToken{},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1704,8 +1704,11 @@ func TestGetFuzzDB(t *testing.T) {
|
|||||||
if err2 != nil {
|
if err2 != nil {
|
||||||
t.Fatalf("Second GetFuzzDB call failed: %v", err2)
|
t.Fatalf("Second GetFuzzDB call failed: %v", err2)
|
||||||
}
|
}
|
||||||
if db2 != db {
|
if db2 == nil {
|
||||||
t.Fatal("GetFuzzDB should return the same database instance")
|
t.Fatal("Second GetFuzzDB returned nil database")
|
||||||
|
}
|
||||||
|
if db2 == db {
|
||||||
|
t.Fatal("GetFuzzDB should return a new database instance for each call")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ func (h *APIHandler) GetAPIInfo(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
apiInfo := map[string]any{
|
apiInfo := map[string]any{
|
||||||
"name": fmt.Sprintf("%s API", h.config.App.Title),
|
"name": fmt.Sprintf("%s API", h.config.App.Title),
|
||||||
"version": version.Version,
|
"version": version.GetVersion(),
|
||||||
"description": "Y Combinator-style news board API",
|
"description": "Y Combinator-style news board API",
|
||||||
"endpoints": map[string]any{
|
"endpoints": map[string]any{
|
||||||
"authentication": map[string]any{
|
"authentication": map[string]any{
|
||||||
@@ -145,7 +145,7 @@ func (h *APIHandler) GetHealth(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if h.healthChecker != nil {
|
if h.healthChecker != nil {
|
||||||
health := h.healthChecker.CheckHealth()
|
health := h.healthChecker.CheckHealth()
|
||||||
health["version"] = version.Version
|
health["version"] = version.GetVersion()
|
||||||
SendSuccessResponse(w, "Health check successful", health)
|
SendSuccessResponse(w, "Health check successful", health)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -155,7 +155,7 @@ func (h *APIHandler) GetHealth(w http.ResponseWriter, r *http.Request) {
|
|||||||
health := map[string]any{
|
health := map[string]any{
|
||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
"timestamp": currentTimestamp,
|
"timestamp": currentTimestamp,
|
||||||
"version": version.Version,
|
"version": version.GetVersion(),
|
||||||
"services": map[string]any{
|
"services": map[string]any{
|
||||||
"database": "connected",
|
"database": "connected",
|
||||||
"api": "running",
|
"api": "running",
|
||||||
@@ -230,7 +230,7 @@ func (h *APIHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
|||||||
},
|
},
|
||||||
"system": map[string]any{
|
"system": map[string]any{
|
||||||
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
||||||
"version": version.Version,
|
"version": version.GetVersion(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,26 +44,6 @@ type AuthHandler struct {
|
|||||||
|
|
||||||
type AuthResponse = CommonResponse
|
type AuthResponse = CommonResponse
|
||||||
|
|
||||||
type AuthTokensResponse struct {
|
|
||||||
Success bool `json:"success" example:"true"`
|
|
||||||
Message string `json:"message" example:"Authentication successful"`
|
|
||||||
Data AuthTokensDetail `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthTokensDetail struct {
|
|
||||||
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
|
|
||||||
RefreshToken string `json:"refresh_token" example:"f94d4ddc7d9b4fcb9d3a2c44c400b780"`
|
|
||||||
User AuthUserSummary `json:"user"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthUserSummary struct {
|
|
||||||
ID uint `json:"id" example:"42"`
|
|
||||||
Username string `json:"username" example:"janedoe"`
|
|
||||||
Email string `json:"email" example:"jane@example.com"`
|
|
||||||
EmailVerified bool `json:"email_verified" example:"true"`
|
|
||||||
Locked bool `json:"locked" example:"false"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.UserRepository) *AuthHandler {
|
func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.UserRepository) *AuthHandler {
|
||||||
return &AuthHandler{
|
return &AuthHandler{
|
||||||
authService: authService,
|
authService: authService,
|
||||||
@@ -77,28 +57,28 @@ func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.User
|
|||||||
// @Accept json
|
// @Accept json
|
||||||
// @Produce json
|
// @Produce json
|
||||||
// @Param request body dto.LoginRequest true "Login credentials"
|
// @Param request body dto.LoginRequest true "Login credentials"
|
||||||
// @Success 200 {object} AuthTokensResponse "Authentication successful"
|
// @Success 200 {object} AuthResponse "Authentication successful"
|
||||||
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
|
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
|
||||||
// @Failure 401 {object} AuthResponse "Invalid credentials"
|
// @Failure 401 {object} AuthResponse "Invalid credentials"
|
||||||
// @Failure 403 {object} AuthResponse "Account is locked"
|
// @Failure 403 {object} AuthResponse "Account is locked"
|
||||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||||
// @Router /api/auth/login [post]
|
// @Router /api/auth/login [post]
|
||||||
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||||
req, ok := GetValidatedDTO[dto.LoginRequest](r)
|
request, ok := GetValidatedDTO[dto.LoginRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
username := security.SanitizeUsername(req.Username)
|
username := security.SanitizeUsername(request.Username)
|
||||||
password := strings.TrimSpace(req.Password)
|
password := strings.TrimSpace(request.Password)
|
||||||
|
|
||||||
result, err := h.authService.Login(username, password)
|
result, err := h.authService.Login(username, password)
|
||||||
if !HandleServiceError(w, err, "Authentication failed", http.StatusInternalServerError) {
|
if !HandleServiceError(w, err, "Authentication failed", http.StatusInternalServerError) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "Authentication successful", result)
|
responseDTO := dto.ToAuthResponseDTO(result)
|
||||||
|
SendSuccessResponse(w, "Authentication successful", responseDTO)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Register a new user
|
// @Summary Register a new user
|
||||||
@@ -113,31 +93,16 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
|||||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||||
// @Router /api/auth/register [post]
|
// @Router /api/auth/register [post]
|
||||||
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
|
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||||
req, ok := GetValidatedDTO[dto.RegisterRequest](r)
|
request, ok := GetValidatedDTO[dto.RegisterRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
username := strings.TrimSpace(req.Username)
|
username := strings.TrimSpace(request.Username)
|
||||||
email := strings.TrimSpace(req.Email)
|
email := strings.TrimSpace(request.Email)
|
||||||
password := strings.TrimSpace(req.Password)
|
password := strings.TrimSpace(request.Password)
|
||||||
|
|
||||||
username = security.SanitizeUsername(username)
|
username = security.SanitizeUsername(username)
|
||||||
if err := validation.ValidateUsername(username); err != nil {
|
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validation.ValidateEmail(email); err != nil {
|
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validation.ValidatePassword(password); err != nil {
|
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := h.authService.Register(username, email, password)
|
result, err := h.authService.Register(username, email, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -151,22 +116,8 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
userData := map[string]any{
|
responseDTO := dto.ToRegistrationResponseDTO(result.User, result.VerificationSent)
|
||||||
"id": result.User.ID,
|
SendCreatedResponse(w, "Registration successful. Check your email to confirm your account.", responseDTO)
|
||||||
"username": result.User.Username,
|
|
||||||
"email": result.User.Email,
|
|
||||||
"email_verified": result.User.EmailVerified,
|
|
||||||
"created_at": result.User.CreatedAt,
|
|
||||||
"updated_at": result.User.UpdatedAt,
|
|
||||||
"deleted_at": result.User.DeletedAt,
|
|
||||||
}
|
|
||||||
|
|
||||||
responseData := map[string]any{
|
|
||||||
"user": userData,
|
|
||||||
"verification_sent": result.VerificationSent,
|
|
||||||
}
|
|
||||||
|
|
||||||
SendCreatedResponse(w, "Registration successful. Check your email to confirm your account.", responseData)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Confirm email address
|
// @Summary Confirm email address
|
||||||
@@ -192,9 +143,7 @@ func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
userDTO := dto.ToUserDTO(user)
|
userDTO := dto.ToUserDTO(user)
|
||||||
SendSuccessResponse(w, "Email confirmed successfully", map[string]any{
|
SendSuccessResponse(w, "Email confirmed successfully", userDTO)
|
||||||
"user": userDTO,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Resend verification email
|
// @Summary Resend verification email
|
||||||
@@ -212,13 +161,12 @@ func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
|
|||||||
// @Failure 500 {object} AuthResponse
|
// @Failure 500 {object} AuthResponse
|
||||||
// @Router /api/auth/resend-verification [post]
|
// @Router /api/auth/resend-verification [post]
|
||||||
func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Request) {
|
func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Request) {
|
||||||
req, ok := GetValidatedDTO[dto.ResendVerificationRequest](r)
|
request, ok := GetValidatedDTO[dto.ResendVerificationRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
email := strings.TrimSpace(req.Email)
|
email := strings.TrimSpace(request.Email)
|
||||||
|
|
||||||
if email == "" {
|
if email == "" {
|
||||||
SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
|
SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
|
||||||
@@ -244,9 +192,10 @@ func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Req
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "Verification email sent successfully", map[string]any{
|
responseDTO := dto.MessageResponseDTO{
|
||||||
"message": "Check your inbox for the verification link",
|
Message: "Check your inbox for the verification link",
|
||||||
})
|
}
|
||||||
|
SendSuccessResponse(w, "Verification email sent successfully", responseDTO)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Get current user profile
|
// @Summary Get current user profile
|
||||||
@@ -285,13 +234,12 @@ func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
|
|||||||
// @Failure 400 {object} AuthResponse "Invalid request data"
|
// @Failure 400 {object} AuthResponse "Invalid request data"
|
||||||
// @Router /api/auth/forgot-password [post]
|
// @Router /api/auth/forgot-password [post]
|
||||||
func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
|
func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
|
||||||
req, ok := GetValidatedDTO[dto.ForgotPasswordRequest](r)
|
request, ok := GetValidatedDTO[dto.ForgotPasswordRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail)
|
usernameOrEmail := strings.TrimSpace(request.UsernameOrEmail)
|
||||||
|
|
||||||
if usernameOrEmail == "" {
|
if usernameOrEmail == "" {
|
||||||
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
|
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
|
||||||
@@ -301,7 +249,7 @@ func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Reques
|
|||||||
if err := h.authService.RequestPasswordReset(usernameOrEmail); err != nil {
|
if err := h.authService.RequestPasswordReset(usernameOrEmail); err != nil {
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "If an account with that username or email exists, we've sent a password reset link.", nil)
|
SendSuccessResponse(w, "If an account with that username or email exists, we've sent a password reset link.", dto.EmptyResponseDTO{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Reset password
|
// @Summary Reset password
|
||||||
@@ -315,25 +263,19 @@ func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Reques
|
|||||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||||
// @Router /api/auth/reset-password [post]
|
// @Router /api/auth/reset-password [post]
|
||||||
func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
|
func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
|
||||||
req, ok := GetValidatedDTO[dto.ResetPasswordRequest](r)
|
request, ok := GetValidatedDTO[dto.ResetPasswordRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token := strings.TrimSpace(req.Token)
|
token := strings.TrimSpace(request.Token)
|
||||||
newPassword := strings.TrimSpace(req.NewPassword)
|
newPassword := strings.TrimSpace(request.NewPassword)
|
||||||
|
|
||||||
if token == "" {
|
if token == "" {
|
||||||
SendErrorResponse(w, "Token is required", http.StatusBadRequest)
|
SendErrorResponse(w, "Token is required", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validation.ValidatePassword(newPassword); err != nil {
|
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.authService.ResetPassword(token, newPassword); err != nil {
|
if err := h.authService.ResetPassword(token, newPassword); err != nil {
|
||||||
switch {
|
switch {
|
||||||
case strings.Contains(err.Error(), "expired"):
|
case strings.Contains(err.Error(), "expired"):
|
||||||
@@ -346,7 +288,7 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "Password reset successfully. You can now sign in with your new password.", nil)
|
SendSuccessResponse(w, "Password reset successfully. You can now sign in with your new password.", dto.EmptyResponseDTO{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Update email address
|
// @Summary Update email address
|
||||||
@@ -369,17 +311,12 @@ func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, ok := GetValidatedDTO[dto.UpdateEmailRequest](r)
|
request, ok := GetValidatedDTO[dto.UpdateEmailRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
email := strings.TrimSpace(req.Email)
|
email := strings.TrimSpace(request.Email)
|
||||||
if err := validation.ValidateEmail(email); err != nil {
|
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := h.authService.UpdateEmail(userID, email)
|
user, err := h.authService.UpdateEmail(userID, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -397,9 +334,7 @@ func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
userDTO := dto.ToUserDTO(user)
|
userDTO := dto.ToUserDTO(user)
|
||||||
SendSuccessResponse(w, "Email updated. Check your inbox to confirm the new address.", map[string]any{
|
SendSuccessResponse(w, "Email updated. Check your inbox to confirm the new address.", userDTO)
|
||||||
"user": userDTO,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Update username
|
// @Summary Update username
|
||||||
@@ -421,17 +356,12 @@ func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, ok := GetValidatedDTO[dto.UpdateUsernameRequest](r)
|
request, ok := GetValidatedDTO[dto.UpdateUsernameRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
username := strings.TrimSpace(req.Username)
|
username := strings.TrimSpace(request.Username)
|
||||||
if err := validation.ValidateUsername(username); err != nil {
|
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := h.authService.UpdateUsername(userID, username)
|
user, err := h.authService.UpdateUsername(userID, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -445,9 +375,7 @@ func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
userDTO := dto.ToUserDTO(user)
|
userDTO := dto.ToUserDTO(user)
|
||||||
SendSuccessResponse(w, "Username updated successfully.", map[string]any{
|
SendSuccessResponse(w, "Username updated successfully.", userDTO)
|
||||||
"user": userDTO,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Update password
|
// @Summary Update password
|
||||||
@@ -468,24 +396,13 @@ func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, ok := GetValidatedDTO[dto.UpdatePasswordRequest](r)
|
request, ok := GetValidatedDTO[dto.UpdatePasswordRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
currentPassword := strings.TrimSpace(req.CurrentPassword)
|
currentPassword := strings.TrimSpace(request.CurrentPassword)
|
||||||
newPassword := strings.TrimSpace(req.NewPassword)
|
newPassword := strings.TrimSpace(request.NewPassword)
|
||||||
|
|
||||||
if currentPassword == "" {
|
|
||||||
SendErrorResponse(w, "Current password is required", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validation.ValidatePassword(newPassword); err != nil {
|
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := h.authService.UpdatePassword(userID, currentPassword, newPassword)
|
user, err := h.authService.UpdatePassword(userID, currentPassword, newPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -498,9 +415,7 @@ func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
userDTO := dto.ToUserDTO(user)
|
userDTO := dto.ToUserDTO(user)
|
||||||
SendSuccessResponse(w, "Password updated successfully.", map[string]any{
|
SendSuccessResponse(w, "Password updated successfully.", userDTO)
|
||||||
"user": userDTO,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Request account deletion
|
// @Summary Request account deletion
|
||||||
@@ -530,7 +445,7 @@ func (h *AuthHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "Check your inbox for a confirmation link to finish deleting your account.", nil)
|
SendSuccessResponse(w, "Check your inbox for a confirmation link to finish deleting your account.", dto.EmptyResponseDTO{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Confirm account deletion
|
// @Summary Confirm account deletion
|
||||||
@@ -545,38 +460,39 @@ func (h *AuthHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||||
// @Router /api/auth/account/confirm [post]
|
// @Router /api/auth/account/confirm [post]
|
||||||
func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
|
func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
|
||||||
req, ok := GetValidatedDTO[dto.ConfirmAccountDeletionRequest](r)
|
request, ok := GetValidatedDTO[dto.ConfirmAccountDeletionRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token := strings.TrimSpace(req.Token)
|
token := strings.TrimSpace(request.Token)
|
||||||
|
|
||||||
if token == "" {
|
if token == "" {
|
||||||
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
|
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.authService.ConfirmAccountDeletionWithPosts(token, req.DeletePosts); err != nil {
|
if err := h.authService.ConfirmAccountDeletionWithPosts(token, request.DeletePosts); err != nil {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, services.ErrInvalidDeletionToken):
|
case errors.Is(err, services.ErrInvalidDeletionToken):
|
||||||
SendErrorResponse(w, "This deletion link is invalid or has expired.", http.StatusBadRequest)
|
SendErrorResponse(w, "This deletion link is invalid or has expired.", http.StatusBadRequest)
|
||||||
case errors.Is(err, services.ErrEmailSenderUnavailable):
|
case errors.Is(err, services.ErrEmailSenderUnavailable):
|
||||||
SendErrorResponse(w, "Account deletion isn't available right now because email delivery is disabled.", http.StatusServiceUnavailable)
|
SendErrorResponse(w, "Account deletion isn't available right now because email delivery is disabled.", http.StatusServiceUnavailable)
|
||||||
case errors.Is(err, services.ErrDeletionEmailFailed):
|
case errors.Is(err, services.ErrDeletionEmailFailed):
|
||||||
SendSuccessResponse(w, "Your account has been deleted, but we couldn't send the confirmation email.", map[string]any{
|
responseDTO := dto.AccountDeletionResponseDTO{
|
||||||
"posts_deleted": req.DeletePosts,
|
PostsDeleted: request.DeletePosts,
|
||||||
})
|
}
|
||||||
|
SendSuccessResponse(w, "Your account has been deleted, but we couldn't send the confirmation email.", responseDTO)
|
||||||
default:
|
default:
|
||||||
SendErrorResponse(w, "We couldn't confirm the deletion right now.", http.StatusInternalServerError)
|
SendErrorResponse(w, "We couldn't confirm the deletion right now.", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "Your account has been deleted.", map[string]any{
|
responseDTO := dto.AccountDeletionResponseDTO{
|
||||||
"posts_deleted": req.DeletePosts,
|
PostsDeleted: request.DeletePosts,
|
||||||
})
|
}
|
||||||
|
SendSuccessResponse(w, "Your account has been deleted.", responseDTO)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Logout user
|
// @Summary Logout user
|
||||||
@@ -589,39 +505,39 @@ func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Requ
|
|||||||
// @Failure 401 {object} AuthResponse "Authentication required"
|
// @Failure 401 {object} AuthResponse "Authentication required"
|
||||||
// @Router /api/auth/logout [post]
|
// @Router /api/auth/logout [post]
|
||||||
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||||
SendSuccessResponse(w, "Logged out successfully", nil)
|
SendSuccessResponse(w, "Logged out successfully", dto.EmptyResponseDTO{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Refresh access token
|
// @Summary Refresh access token
|
||||||
// @Description Use a refresh token to get a new access token. This endpoint allows clients to obtain a new access token using a valid refresh token without requiring user credentials.
|
// @Description Use a refresh token to get a new access token. The refresh token is rotated on success, and the previous refresh token becomes invalid.
|
||||||
// @Tags auth
|
// @Tags auth
|
||||||
// @Accept json
|
// @Accept json
|
||||||
// @Produce json
|
// @Produce json
|
||||||
// @Param request body dto.RefreshTokenRequest true "Refresh token data"
|
// @Param request body dto.RefreshTokenRequest true "Refresh token data"
|
||||||
// @Success 200 {object} AuthTokensResponse "Token refreshed successfully"
|
// @Success 200 {object} AuthResponse "Token refreshed successfully"
|
||||||
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
|
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
|
||||||
// @Failure 401 {object} AuthResponse "Invalid or expired refresh token"
|
// @Failure 401 {object} AuthResponse "Invalid or expired refresh token"
|
||||||
// @Failure 403 {object} AuthResponse "Account is locked"
|
// @Failure 403 {object} AuthResponse "Account is locked"
|
||||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||||
// @Router /api/auth/refresh [post]
|
// @Router /api/auth/refresh [post]
|
||||||
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
||||||
req, ok := GetValidatedDTO[dto.RefreshTokenRequest](r)
|
request, ok := GetValidatedDTO[dto.RefreshTokenRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.RefreshToken == "" {
|
if request.RefreshToken == "" {
|
||||||
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
|
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := h.authService.RefreshAccessToken(req.RefreshToken)
|
result, err := h.authService.RefreshAccessToken(request.RefreshToken)
|
||||||
if !HandleServiceError(w, err, "Token refresh failed", http.StatusInternalServerError) {
|
if !HandleServiceError(w, err, "Token refresh failed", http.StatusInternalServerError) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "Token refreshed successfully", result)
|
responseDTO := dto.ToAuthResponseDTO(result)
|
||||||
|
SendSuccessResponse(w, "Token refreshed successfully", responseDTO)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Revoke refresh token
|
// @Summary Revoke refresh token
|
||||||
@@ -637,24 +553,23 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||||
// @Router /api/auth/revoke [post]
|
// @Router /api/auth/revoke [post]
|
||||||
func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
|
func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
|
||||||
req, ok := GetValidatedDTO[dto.RevokeTokenRequest](r)
|
request, ok := GetValidatedDTO[dto.RevokeTokenRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.RefreshToken == "" {
|
if request.RefreshToken == "" {
|
||||||
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
|
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.authService.RevokeRefreshToken(req.RefreshToken)
|
err := h.authService.RevokeRefreshToken(request.RefreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SendErrorResponse(w, "Failed to revoke token", http.StatusInternalServerError)
|
SendErrorResponse(w, "Failed to revoke token", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "Token revoked successfully", nil)
|
SendSuccessResponse(w, "Token revoked successfully", dto.EmptyResponseDTO{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Revoke all user tokens
|
// @Summary Revoke all user tokens
|
||||||
@@ -679,7 +594,7 @@ func (h *AuthHandler) RevokeAllTokens(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "All tokens revoked successfully", nil)
|
SendSuccessResponse(w, "All tokens revoked successfully", dto.EmptyResponseDTO{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
|
func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
|
||||||
|
|||||||
@@ -693,7 +693,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
|
|||||||
userID: 1,
|
userID: 1,
|
||||||
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Invalid request",
|
expectedError: "Invalid JSON",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty email",
|
name: "empty email",
|
||||||
@@ -701,7 +701,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
|
|||||||
userID: 1,
|
userID: 1,
|
||||||
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Email is required",
|
expectedError: "email is required",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "email already taken",
|
name: "email already taken",
|
||||||
@@ -788,7 +788,7 @@ func TestAuthHandlerUpdateUsername(t *testing.T) {
|
|||||||
userID: 1,
|
userID: 1,
|
||||||
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Username is required",
|
expectedError: "username is required",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "username already taken",
|
name: "username already taken",
|
||||||
@@ -876,7 +876,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
|
|||||||
userID: 1,
|
userID: 1,
|
||||||
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Current password is required",
|
expectedError: "current_password is required",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty new password",
|
name: "empty new password",
|
||||||
@@ -884,7 +884,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
|
|||||||
userID: 1,
|
userID: 1,
|
||||||
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Password is required",
|
expectedError: "new_password is required",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "short new password",
|
name: "short new password",
|
||||||
@@ -892,7 +892,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
|
|||||||
userID: 1,
|
userID: 1,
|
||||||
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
mockSetup: func(repo *testutils.UserRepositoryStub) {},
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Password must be at least 8 characters long",
|
expectedError: "new_password must be at least 8 characters",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "incorrect current password",
|
name: "incorrect current password",
|
||||||
@@ -1042,13 +1042,13 @@ func TestAuthHandlerResendVerificationEmail(t *testing.T) {
|
|||||||
name: "invalid json",
|
name: "invalid json",
|
||||||
body: "not-json",
|
body: "not-json",
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Invalid request",
|
expectedError: "Invalid JSON",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "missing email",
|
name: "missing email",
|
||||||
body: `{}`,
|
body: `{}`,
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Email address is required",
|
expectedError: "email is required",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "account not found",
|
name: "account not found",
|
||||||
@@ -1167,13 +1167,13 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
|
|||||||
name: "invalid json",
|
name: "invalid json",
|
||||||
body: "not-json",
|
body: "not-json",
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Invalid request",
|
expectedError: "Invalid JSON",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "missing token",
|
name: "missing token",
|
||||||
body: `{}`,
|
body: `{}`,
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Deletion token is required",
|
expectedError: "token is required",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid token from service",
|
name: "invalid token from service",
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"goyco/internal/dto"
|
"goyco/internal/dto"
|
||||||
"goyco/internal/middleware"
|
"goyco/internal/middleware"
|
||||||
"goyco/internal/services"
|
"goyco/internal/services"
|
||||||
|
"goyco/internal/validation"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -272,13 +273,51 @@ func HandleServiceError(w http.ResponseWriter, err error, defaultMsg string, def
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetValidatedDTO[T any](r *http.Request) (*T, bool) {
|
func GetValidatedDTO[T any](w http.ResponseWriter, r *http.Request) (*T, bool) {
|
||||||
dtoVal := middleware.GetValidatedDTOFromContext(r.Context())
|
dtoVal := middleware.GetValidatedDTOFromContext(r.Context())
|
||||||
if dtoVal == nil {
|
dtoTypeInContext := middleware.GetDTOTypeFromContext(r.Context())
|
||||||
|
|
||||||
|
var dto *T
|
||||||
|
needsValidation := false
|
||||||
|
|
||||||
|
if dtoVal != nil {
|
||||||
|
var ok bool
|
||||||
|
dto, ok = dtoVal.(*T)
|
||||||
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
dto, ok := dtoVal.(*T)
|
if dtoTypeInContext == nil {
|
||||||
return dto, ok
|
needsValidation = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var decoded T
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&decoded); err != nil {
|
||||||
|
SendErrorResponse(w, "Invalid JSON", http.StatusBadRequest)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
dto = &decoded
|
||||||
|
needsValidation = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsValidation {
|
||||||
|
if err := validation.ValidateStruct(dto); err != nil {
|
||||||
|
var errorMessages []string
|
||||||
|
if structErr, ok := err.(*validation.StructValidationError); ok {
|
||||||
|
errorMessages = make([]string, len(structErr.Errors))
|
||||||
|
for i, fieldError := range structErr.Errors {
|
||||||
|
errorMessages[i] = fieldError.Message
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
errorMessages = []string{err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
errorMsg := strings.Join(errorMessages, "; ")
|
||||||
|
SendErrorResponse(w, errorMsg, http.StatusBadRequest)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dto, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithValidation[T any](validationMiddleware func(http.Handler) http.Handler, handler http.HandlerFunc) http.HandlerFunc {
|
func WithValidation[T any](validationMiddleware func(http.Handler) http.Handler, handler http.HandlerFunc) http.HandlerFunc {
|
||||||
|
|||||||
@@ -35,25 +35,25 @@ func FuzzJSONParsing(f *testing.F) {
|
|||||||
func FuzzURLParsing(f *testing.F) {
|
func FuzzURLParsing(f *testing.F) {
|
||||||
helper := fuzz.NewFuzzTestHelper()
|
helper := fuzz.NewFuzzTestHelper()
|
||||||
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
||||||
|
var sanitized strings.Builder
|
||||||
sanitized := ""
|
sanitized.Grow(len(input))
|
||||||
|
sanitizedLen := 0
|
||||||
for _, char := range input {
|
for _, char := range input {
|
||||||
|
|
||||||
if (char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') ||
|
if (char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') ||
|
||||||
(char >= '0' && char <= '9') || char == '-' || char == '_' {
|
(char >= '0' && char <= '9') || char == '-' || char == '_' {
|
||||||
sanitized += string(char)
|
sanitized.WriteRune(char)
|
||||||
|
sanitizedLen++
|
||||||
|
if sanitizedLen >= 20 {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(sanitized) > 20 {
|
if sanitizedLen == 0 {
|
||||||
sanitized = sanitized[:20]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(sanitized) == 0 {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
url := "/api/posts/" + sanitized
|
url := "/api/posts/" + sanitized.String()
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
|
|
||||||
pathParts := strings.Split(req.URL.Path, "/")
|
pathParts := strings.Split(req.URL.Path, "/")
|
||||||
@@ -67,46 +67,52 @@ func FuzzURLParsing(f *testing.F) {
|
|||||||
func FuzzQueryParameters(f *testing.F) {
|
func FuzzQueryParameters(f *testing.F) {
|
||||||
helper := fuzz.NewFuzzTestHelper()
|
helper := fuzz.NewFuzzTestHelper()
|
||||||
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
||||||
|
|
||||||
if !utf8.ValidString(input) {
|
if !utf8.ValidString(input) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sanitized := ""
|
var sanitized strings.Builder
|
||||||
|
sanitized.Grow(len(input))
|
||||||
|
sanitizedLen := 0
|
||||||
for _, char := range input {
|
for _, char := range input {
|
||||||
|
|
||||||
if char >= 32 && char <= 126 {
|
if char >= 32 && char <= 126 {
|
||||||
switch char {
|
switch char {
|
||||||
case ' ', '\n', '\r', '\t':
|
case ' ', '\n', '\r', '\t':
|
||||||
|
|
||||||
continue
|
continue
|
||||||
case '&':
|
case '&':
|
||||||
sanitized += "%26"
|
sanitized.WriteString("%26")
|
||||||
|
sanitizedLen += 3
|
||||||
case '=':
|
case '=':
|
||||||
sanitized += "%3D"
|
sanitized.WriteString("%3D")
|
||||||
|
sanitizedLen += 3
|
||||||
case '?':
|
case '?':
|
||||||
sanitized += "%3F"
|
sanitized.WriteString("%3F")
|
||||||
|
sanitizedLen += 3
|
||||||
case '#':
|
case '#':
|
||||||
sanitized += "%23"
|
sanitized.WriteString("%23")
|
||||||
|
sanitizedLen += 3
|
||||||
case '/':
|
case '/':
|
||||||
sanitized += "%2F"
|
sanitized.WriteString("%2F")
|
||||||
|
sanitizedLen += 3
|
||||||
case '\\':
|
case '\\':
|
||||||
sanitized += "%5C"
|
sanitized.WriteString("%5C")
|
||||||
|
sanitizedLen += 3
|
||||||
default:
|
default:
|
||||||
sanitized += string(char)
|
sanitized.WriteRune(char)
|
||||||
|
sanitizedLen++
|
||||||
|
}
|
||||||
|
if sanitizedLen >= 100 {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(sanitized) > 100 {
|
if sanitizedLen == 0 {
|
||||||
sanitized = sanitized[:100]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(sanitized) == 0 {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
query := "?q=" + sanitized + "&limit=10&offset=0"
|
query := "?q=" + sanitized.String() + "&limit=10&offset=0"
|
||||||
req := httptest.NewRequest("GET", "/api/posts/search"+query, nil)
|
req := httptest.NewRequest("GET", "/api/posts/search"+query, nil)
|
||||||
|
|
||||||
q := req.URL.Query().Get("q")
|
q := req.URL.Query().Get("q")
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"goyco/internal/repositories"
|
"goyco/internal/repositories"
|
||||||
"goyco/internal/security"
|
"goyco/internal/security"
|
||||||
"goyco/internal/services"
|
"goyco/internal/services"
|
||||||
"goyco/internal/validation"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
@@ -63,13 +62,8 @@ func (h *PostHandler) GetPosts(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postDTOs := dto.ToPostDTOs(posts)
|
responseDTO := dto.ToPostListDTO(posts, limit, offset)
|
||||||
SendSuccessResponse(w, "Posts retrieved successfully", map[string]any{
|
SendSuccessResponse(w, "Posts retrieved successfully", responseDTO)
|
||||||
"posts": postDTOs,
|
|
||||||
"count": len(postDTOs),
|
|
||||||
"limit": limit,
|
|
||||||
"offset": offset,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Get a single post
|
// @Summary Get a single post
|
||||||
@@ -115,9 +109,8 @@ func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
|
|||||||
// @Failure 500 {object} PostResponse "Internal server error"
|
// @Failure 500 {object} PostResponse "Internal server error"
|
||||||
// @Router /api/posts [post]
|
// @Router /api/posts [post]
|
||||||
func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
|
func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
|
||||||
req, ok := GetValidatedDTO[dto.CreatePostRequest](r)
|
request, ok := GetValidatedDTO[dto.CreatePostRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,9 +119,9 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
title := security.SanitizeInput(req.Title)
|
title := security.SanitizeInput(request.Title)
|
||||||
url := security.SanitizeURL(req.URL)
|
url := security.SanitizeURL(request.URL)
|
||||||
content := security.SanitizePostContent(req.Content)
|
content := security.SanitizePostContent(request.Content)
|
||||||
|
|
||||||
if url == "" {
|
if url == "" {
|
||||||
SendErrorResponse(w, "Invalid URL", http.StatusBadRequest)
|
SendErrorResponse(w, "Invalid URL", http.StatusBadRequest)
|
||||||
@@ -229,14 +222,8 @@ func (h *PostHandler) SearchPosts(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postDTOs := dto.ToPostDTOs(posts)
|
responseDTO := dto.ToSearchPostListDTO(posts, query, limit, offset)
|
||||||
SendSuccessResponse(w, "Search results retrieved successfully", map[string]any{
|
SendSuccessResponse(w, "Search results retrieved successfully", responseDTO)
|
||||||
"posts": postDTOs,
|
|
||||||
"count": len(postDTOs),
|
|
||||||
"query": query,
|
|
||||||
"limit": limit,
|
|
||||||
"offset": offset,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Update a post
|
// @Summary Update a post
|
||||||
@@ -275,24 +262,13 @@ func (h *PostHandler) UpdatePost(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, ok := GetValidatedDTO[dto.UpdatePostRequest](r)
|
request, ok := GetValidatedDTO[dto.UpdatePostRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
title := security.SanitizeInput(req.Title)
|
title := security.SanitizeInput(request.Title)
|
||||||
content := security.SanitizePostContent(req.Content)
|
content := security.SanitizePostContent(request.Content)
|
||||||
|
|
||||||
if err := validation.ValidateTitle(title); err != nil {
|
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validation.ValidateContent(content); err != nil {
|
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
post.Title = title
|
post.Title = title
|
||||||
post.Content = content
|
post.Content = content
|
||||||
@@ -351,7 +327,7 @@ func (h *PostHandler) DeletePost(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "Post deleted successfully", nil)
|
SendSuccessResponse(w, "Post deleted successfully", dto.EmptyResponseDTO{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Fetch title from URL
|
// @Summary Fetch title from URL
|
||||||
@@ -393,9 +369,10 @@ func (h *PostHandler) FetchTitleFromURL(w http.ResponseWriter, r *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "Title fetched successfully", map[string]string{
|
responseDTO := dto.TitleResponseDTO{
|
||||||
"title": title,
|
Title: title,
|
||||||
})
|
}
|
||||||
|
SendSuccessResponse(w, "Title fetched successfully", responseDTO)
|
||||||
}
|
}
|
||||||
|
|
||||||
func translatePostCreateError(err error) (string, int) {
|
func translatePostCreateError(err error) (string, int) {
|
||||||
|
|||||||
@@ -10,14 +10,15 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/jackc/pgconn"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"goyco/internal/database"
|
"goyco/internal/database"
|
||||||
"goyco/internal/middleware"
|
"goyco/internal/middleware"
|
||||||
"goyco/internal/repositories"
|
"goyco/internal/repositories"
|
||||||
"goyco/internal/services"
|
"goyco/internal/services"
|
||||||
"goyco/internal/testutils"
|
"goyco/internal/testutils"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func decodeHandlerResponse(t *testing.T, rr *httptest.ResponseRecorder) map[string]any {
|
func decodeHandlerResponse(t *testing.T, rr *httptest.ResponseRecorder) map[string]any {
|
||||||
@@ -277,7 +278,7 @@ func TestPostHandlerCreatePostSuccess(t *testing.T) {
|
|||||||
|
|
||||||
handler := NewPostHandler(repo, fetcher, nil)
|
handler := NewPostHandler(repo, fetcher, nil)
|
||||||
|
|
||||||
request := createCreatePostRequest(`{"title":" ","url":"https://example.com","content":"Go"}`)
|
request := createCreatePostRequest(`{"title":"","url":"https://example.com","content":"Go"}`)
|
||||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42))
|
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42))
|
||||||
request = request.WithContext(ctx)
|
request = request.WithContext(ctx)
|
||||||
|
|
||||||
@@ -310,7 +311,7 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
recorder = httptest.NewRecorder()
|
recorder = httptest.NewRecorder()
|
||||||
request = createCreatePostRequest(`{"title":"ok","url":"https://example.com"}`)
|
request = createCreatePostRequest(`{"title":"okay","url":"https://example.com"}`)
|
||||||
handler.CreatePost(recorder, request)
|
handler.CreatePost(recorder, request)
|
||||||
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
|
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
@@ -334,7 +335,7 @@ func TestPostHandlerCreatePostTitleFetcherErrors(t *testing.T) {
|
|||||||
return "", tc.err
|
return "", tc.err
|
||||||
}}
|
}}
|
||||||
handler := NewPostHandler(repo, fetcher, nil)
|
handler := NewPostHandler(repo, fetcher, nil)
|
||||||
request := createCreatePostRequest(`{"title":" ","url":"https://example.com"}`)
|
request := createCreatePostRequest(`{"title":"","url":"https://example.com"}`)
|
||||||
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
|
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -466,7 +467,7 @@ func TestPostHandlerUpdatePost(t *testing.T) {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Title is required",
|
expectedError: "title is required",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "short title",
|
name: "short title",
|
||||||
@@ -480,7 +481,7 @@ func TestPostHandlerUpdatePost(t *testing.T) {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedError: "Title must be at least 3 characters",
|
expectedError: "title must be at least 3 characters",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,14 +46,8 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userDTOs := dto.ToSanitizedUserDTOs(users)
|
responseDTO := dto.ToSanitizedUserListDTO(users, limit, offset)
|
||||||
|
SendSuccessResponse(w, "Users retrieved successfully", responseDTO)
|
||||||
SendSuccessResponse(w, "Users retrieved successfully", map[string]any{
|
|
||||||
"users": userDTOs,
|
|
||||||
"count": len(userDTOs),
|
|
||||||
"limit": limit,
|
|
||||||
"offset": offset,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Get user
|
// @Summary Get user
|
||||||
@@ -99,28 +93,12 @@ func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
// @Failure 500 {object} UserResponse "Internal server error"
|
// @Failure 500 {object} UserResponse "Internal server error"
|
||||||
// @Router /api/users [post]
|
// @Router /api/users [post]
|
||||||
func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||||
req, ok := GetValidatedDTO[dto.RegisterRequest](r)
|
request, ok := GetValidatedDTO[dto.RegisterRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validation.ValidateUsername(req.Username); err != nil {
|
result, err := h.authService.Register(request.Username, request.Email, request.Password)
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validation.ValidateEmail(req.Email); err != nil {
|
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validation.ValidatePassword(req.Password); err != nil {
|
|
||||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := h.authService.Register(req.Username, req.Email, req.Password)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var validationErr *validation.ValidationError
|
var validationErr *validation.ValidationError
|
||||||
if errors.As(err, &validationErr) {
|
if errors.As(err, &validationErr) {
|
||||||
@@ -132,10 +110,8 @@ func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SendCreatedResponse(w, "User created successfully. Verification email sent.", map[string]any{
|
responseDTO := dto.ToRegistrationResponseDTO(result.User, result.VerificationSent)
|
||||||
"user": result.User,
|
SendCreatedResponse(w, "User created successfully. Verification email sent.", responseDTO)
|
||||||
"verification_sent": result.VerificationSent,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Get user posts
|
// @Summary Get user posts
|
||||||
@@ -166,13 +142,8 @@ func (h *UserHandler) GetUserPosts(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postDTOs := dto.ToPostDTOs(posts)
|
responseDTO := dto.ToPostListDTO(posts, limit, offset)
|
||||||
SendSuccessResponse(w, "User posts retrieved successfully", map[string]any{
|
SendSuccessResponse(w, "User posts retrieved successfully", responseDTO)
|
||||||
"posts": postDTOs,
|
|
||||||
"count": len(postDTOs),
|
|
||||||
"limit": limit,
|
|
||||||
"offset": offset,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *UserHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
|
func (h *UserHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
|
||||||
|
|||||||
@@ -78,14 +78,13 @@ func (h *VoteHandler) CastVote(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, ok := GetValidatedDTO[dto.CastVoteRequest](r)
|
request, ok := GetValidatedDTO[dto.CastVoteRequest](w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var voteType database.VoteType
|
var voteType database.VoteType
|
||||||
switch req.Type {
|
switch request.Type {
|
||||||
case "up":
|
case "up":
|
||||||
voteType = database.VoteUp
|
voteType = database.VoteUp
|
||||||
case "down":
|
case "down":
|
||||||
@@ -213,22 +212,26 @@ func (h *VoteHandler) GetUserVote(w http.ResponseWriter, r *http.Request) {
|
|||||||
vote, err := h.voteService.GetUserVote(userID, postID, ipAddress, userAgent)
|
vote, err := h.voteService.GetUserVote(userID, postID, ipAddress, userAgent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() == "record not found" {
|
if err.Error() == "record not found" {
|
||||||
SendSuccessResponse(w, "No vote found", map[string]any{
|
responseDTO := dto.VoteResponseDTO{
|
||||||
"has_vote": false,
|
HasVote: false,
|
||||||
"vote": nil,
|
Vote: nil,
|
||||||
"is_anonymous": false,
|
IsAnonymous: false,
|
||||||
})
|
}
|
||||||
|
SendSuccessResponse(w, "No vote found", responseDTO)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
|
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
SendSuccessResponse(w, "Vote retrieved successfully", map[string]any{
|
voteDTO := dto.ToVoteDTO(vote)
|
||||||
"has_vote": true,
|
isAnonymous := vote.UserID == nil
|
||||||
"vote": vote,
|
responseDTO := dto.VoteResponseDTO{
|
||||||
"is_anonymous": false,
|
HasVote: true,
|
||||||
})
|
Vote: &voteDTO,
|
||||||
|
IsAnonymous: isAnonymous,
|
||||||
|
}
|
||||||
|
SendSuccessResponse(w, "Vote retrieved successfully", responseDTO)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Get post votes
|
// @Summary Get post votes
|
||||||
@@ -263,15 +266,12 @@ func (h *VoteHandler) GetPostVotes(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
allVotes := make([]any, 0, len(votes))
|
voteDTOs := dto.ToVoteDTOs(votes)
|
||||||
for _, vote := range votes {
|
responseDTO := dto.VoteListDTO{
|
||||||
allVotes = append(allVotes, vote)
|
Votes: voteDTOs,
|
||||||
|
Count: len(voteDTOs),
|
||||||
}
|
}
|
||||||
|
SendSuccessResponse(w, "Votes retrieved successfully", responseDTO)
|
||||||
SendSuccessResponse(w, "Votes retrieved successfully", map[string]any{
|
|
||||||
"votes": allVotes,
|
|
||||||
"count": len(allVotes),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *VoteHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
|
func (h *VoteHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
|
||||||
|
|||||||
@@ -589,18 +589,22 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
response := assertJSONResponse(t, request, http.StatusOK)
|
response := assertJSONResponse(t, request, http.StatusOK)
|
||||||
if data, ok := getDataFromResponse(response); ok {
|
if data, ok := getDataFromResponse(response); ok {
|
||||||
if newAccessToken, ok := data["access_token"].(string); ok {
|
newAccessToken, _ := data["access_token"].(string)
|
||||||
if newAccessToken == "" {
|
if newAccessToken == "" {
|
||||||
t.Error("Expected new access token in refresh response")
|
t.Error("Expected new access token in refresh response")
|
||||||
}
|
}
|
||||||
|
|
||||||
if newRefreshToken, ok := data["refresh_token"].(string); ok {
|
newRefreshToken, _ := data["refresh_token"].(string)
|
||||||
if newRefreshToken != "" && newRefreshToken == originalRefreshToken {
|
if newRefreshToken == "" {
|
||||||
t.Log("Refresh token rotation may not be implemented (same token returned)")
|
t.Error("Expected new refresh token in refresh response")
|
||||||
}
|
}
|
||||||
}
|
if newRefreshToken == originalRefreshToken {
|
||||||
|
t.Error("Expected refresh token to rotate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request = makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": originalRefreshToken})
|
||||||
|
assertErrorResponse(t, request, http.StatusUnauthorized)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Refresh_After_Account_Lock", func(t *testing.T) {
|
t.Run("Refresh_After_Account_Lock", func(t *testing.T) {
|
||||||
|
|||||||
@@ -610,6 +610,14 @@ func TestIntegration_Services(t *testing.T) {
|
|||||||
t.Error("New access token should be different from original")
|
t.Error("New access token should be different from original")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if newAccessToken.RefreshToken == "" {
|
||||||
|
t.Fatal("Refresh should return a new refresh token")
|
||||||
|
}
|
||||||
|
|
||||||
|
if newAccessToken.RefreshToken == loginResult.RefreshToken {
|
||||||
|
t.Error("Refresh token should rotate")
|
||||||
|
}
|
||||||
|
|
||||||
userID, err := authService.VerifyToken(newAccessToken.AccessToken)
|
userID, err := authService.VerifyToken(newAccessToken.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("New access token should be valid: %v", err)
|
t.Fatalf("New access token should be valid: %v", err)
|
||||||
@@ -618,6 +626,11 @@ func TestIntegration_Services(t *testing.T) {
|
|||||||
if userID != user.ID {
|
if userID != user.ID {
|
||||||
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = authService.RefreshAccessToken(loginResult.RefreshToken)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for rotated refresh token")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Refresh_Token_Expiration", func(t *testing.T) {
|
t.Run("Refresh_Token_Expiration", func(t *testing.T) {
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
package repositories
|
package repositories
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"goyco/internal/database"
|
"goyco/internal/database"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDatabase_AssertUserExists(t *testing.T) {
|
func TestDatabase_AssertUserExists(t *testing.T) {
|
||||||
@@ -384,245 +381,3 @@ func TestDatabase_CreateTestAccountDeletionRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApplyPagination(t *testing.T) {
|
|
||||||
suite := NewTestSuite(t)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
limit int
|
|
||||||
offset int
|
|
||||||
setupQuery func(*gorm.DB) *gorm.DB
|
|
||||||
verifyPagination func(*testing.T, *gorm.DB, int, int)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "limit > 0 and offset > 0",
|
|
||||||
limit: 10,
|
|
||||||
offset: 5,
|
|
||||||
setupQuery: func(db *gorm.DB) *gorm.DB {
|
|
||||||
return db.Model(&database.User{})
|
|
||||||
},
|
|
||||||
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
|
||||||
for i := 0; i < 20; i++ {
|
|
||||||
suite.CreateTestUser(
|
|
||||||
fmt.Sprintf("testuser_%d", i),
|
|
||||||
fmt.Sprintf("user%d@example.com", i),
|
|
||||||
"password123",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
var users []database.User
|
|
||||||
result := query.Find(&users)
|
|
||||||
if result.Error != nil {
|
|
||||||
t.Fatalf("Query failed: %v", result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users) != limit {
|
|
||||||
t.Errorf("Expected %d users, got %d", limit, len(users))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "limit > 0 and offset = 0",
|
|
||||||
limit: 5,
|
|
||||||
offset: 0,
|
|
||||||
setupQuery: func(db *gorm.DB) *gorm.DB {
|
|
||||||
return db.Model(&database.User{})
|
|
||||||
},
|
|
||||||
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
suite.CreateTestUser(
|
|
||||||
fmt.Sprintf("testuser_%d", i),
|
|
||||||
fmt.Sprintf("user%d@example.com", i),
|
|
||||||
"password123",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
var users []database.User
|
|
||||||
result := query.Find(&users)
|
|
||||||
if result.Error != nil {
|
|
||||||
t.Fatalf("Query failed: %v", result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users) != limit {
|
|
||||||
t.Errorf("Expected %d users, got %d", limit, len(users))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "limit = 0 (should not apply limit)",
|
|
||||||
limit: 0,
|
|
||||||
offset: 5,
|
|
||||||
setupQuery: func(db *gorm.DB) *gorm.DB {
|
|
||||||
return db.Model(&database.User{})
|
|
||||||
},
|
|
||||||
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
suite.CreateTestUser(
|
|
||||||
fmt.Sprintf("testuser_%d", i),
|
|
||||||
fmt.Sprintf("user%d@example.com", i),
|
|
||||||
"password123",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
var users []database.User
|
|
||||||
result := query.Find(&users)
|
|
||||||
if result.Error != nil {
|
|
||||||
t.Fatalf("Query failed: %v", result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := 5
|
|
||||||
if len(users) != expected {
|
|
||||||
t.Errorf("Expected %d users with offset %d, got %d", expected, offset, len(users))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "offset = 0 (should not apply offset)",
|
|
||||||
limit: 10,
|
|
||||||
offset: 0,
|
|
||||||
setupQuery: func(db *gorm.DB) *gorm.DB {
|
|
||||||
return db.Model(&database.User{})
|
|
||||||
},
|
|
||||||
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
|
||||||
for i := 0; i < 15; i++ {
|
|
||||||
suite.CreateTestUser(
|
|
||||||
fmt.Sprintf("testuser_%d", i),
|
|
||||||
fmt.Sprintf("user%d@example.com", i),
|
|
||||||
"password123",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
var users []database.User
|
|
||||||
result := query.Find(&users)
|
|
||||||
if result.Error != nil {
|
|
||||||
t.Fatalf("Query failed: %v", result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users) != limit {
|
|
||||||
t.Errorf("Expected %d users, got %d", limit, len(users))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "limit = 0 and offset = 0 (should not apply pagination)",
|
|
||||||
limit: 0,
|
|
||||||
offset: 0,
|
|
||||||
setupQuery: func(db *gorm.DB) *gorm.DB {
|
|
||||||
return db.Model(&database.User{})
|
|
||||||
},
|
|
||||||
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
suite.CreateTestUser(
|
|
||||||
fmt.Sprintf("testuser_%d", i),
|
|
||||||
fmt.Sprintf("user%d@example.com", i),
|
|
||||||
"password123",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
var users []database.User
|
|
||||||
result := query.Find(&users)
|
|
||||||
if result.Error != nil {
|
|
||||||
t.Fatalf("Query failed: %v", result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users) != 10 {
|
|
||||||
t.Errorf("Expected all 10 users, got %d", len(users))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative limit (should not apply limit)",
|
|
||||||
limit: -5,
|
|
||||||
offset: 10,
|
|
||||||
setupQuery: func(db *gorm.DB) *gorm.DB {
|
|
||||||
return db.Model(&database.User{})
|
|
||||||
},
|
|
||||||
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
|
||||||
for i := 0; i < 20; i++ {
|
|
||||||
suite.CreateTestUser(
|
|
||||||
fmt.Sprintf("testuser_%d", i),
|
|
||||||
fmt.Sprintf("user%d@example.com", i),
|
|
||||||
"password123",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
var users []database.User
|
|
||||||
result := query.Find(&users)
|
|
||||||
if result.Error != nil {
|
|
||||||
t.Fatalf("Query failed: %v", result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := 10
|
|
||||||
if len(users) != expected {
|
|
||||||
t.Errorf("Expected %d users with offset %d, got %d", expected, offset, len(users))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative offset (should not apply offset)",
|
|
||||||
limit: 10,
|
|
||||||
offset: -5,
|
|
||||||
setupQuery: func(db *gorm.DB) *gorm.DB {
|
|
||||||
return db.Model(&database.User{})
|
|
||||||
},
|
|
||||||
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
|
||||||
for i := 0; i < 15; i++ {
|
|
||||||
suite.CreateTestUser(
|
|
||||||
fmt.Sprintf("testuser_%d", i),
|
|
||||||
fmt.Sprintf("user%d@example.com", i),
|
|
||||||
"password123",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
var users []database.User
|
|
||||||
result := query.Find(&users)
|
|
||||||
if result.Error != nil {
|
|
||||||
t.Fatalf("Query failed: %v", result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users) != limit {
|
|
||||||
t.Errorf("Expected %d users, got %d", limit, len(users))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "large limit and offset values",
|
|
||||||
limit: 1000,
|
|
||||||
offset: 500,
|
|
||||||
setupQuery: func(db *gorm.DB) *gorm.DB {
|
|
||||||
return db.Model(&database.User{})
|
|
||||||
},
|
|
||||||
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
|
||||||
for i := 0; i < 2000; i++ {
|
|
||||||
suite.CreateTestUser(
|
|
||||||
fmt.Sprintf("testuser_%d", i),
|
|
||||||
fmt.Sprintf("user%d@example.com", i),
|
|
||||||
"password123",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
var users []database.User
|
|
||||||
result := query.Find(&users)
|
|
||||||
if result.Error != nil {
|
|
||||||
t.Fatalf("Query failed: %v", result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users) != limit {
|
|
||||||
t.Errorf("Expected %d users, got %d", limit, len(users))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
suite.Reset()
|
|
||||||
|
|
||||||
baseQuery := tt.setupQuery(suite.DB)
|
|
||||||
paginatedQuery := ApplyPagination(baseQuery, tt.limit, tt.offset)
|
|
||||||
|
|
||||||
tt.verifyPagination(t, paginatedQuery, tt.limit, tt.offset)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
252
internal/repositories/pagination_test.go
Normal file
252
internal/repositories/pagination_test.go
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
package repositories
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"goyco/internal/database"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestApplyPagination(t *testing.T) {
|
||||||
|
suite := NewTestSuite(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
limit int
|
||||||
|
offset int
|
||||||
|
setupQuery func(*gorm.DB) *gorm.DB
|
||||||
|
verifyPagination func(*testing.T, *gorm.DB, int, int)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "limit > 0 and offset > 0",
|
||||||
|
limit: 10,
|
||||||
|
offset: 5,
|
||||||
|
setupQuery: func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Model(&database.User{})
|
||||||
|
},
|
||||||
|
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
||||||
|
for i := range 20 {
|
||||||
|
suite.CreateTestUser(
|
||||||
|
fmt.Sprintf("testuser_%d", i),
|
||||||
|
fmt.Sprintf("user%d@example.com", i),
|
||||||
|
"password123",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []database.User
|
||||||
|
result := query.Find(&users)
|
||||||
|
if result.Error != nil {
|
||||||
|
t.Fatalf("Query failed: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(users) != limit {
|
||||||
|
t.Errorf("Expected %d users, got %d", limit, len(users))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "limit > 0 and offset = 0",
|
||||||
|
limit: 5,
|
||||||
|
offset: 0,
|
||||||
|
setupQuery: func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Model(&database.User{})
|
||||||
|
},
|
||||||
|
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
||||||
|
for i := range 10 {
|
||||||
|
suite.CreateTestUser(
|
||||||
|
fmt.Sprintf("testuser_%d", i),
|
||||||
|
fmt.Sprintf("user%d@example.com", i),
|
||||||
|
"password123",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []database.User
|
||||||
|
result := query.Find(&users)
|
||||||
|
if result.Error != nil {
|
||||||
|
t.Fatalf("Query failed: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(users) != limit {
|
||||||
|
t.Errorf("Expected %d users, got %d", limit, len(users))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "limit = 0 (should not apply limit)",
|
||||||
|
limit: 0,
|
||||||
|
offset: 5,
|
||||||
|
setupQuery: func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Model(&database.User{})
|
||||||
|
},
|
||||||
|
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
||||||
|
for i := range 10 {
|
||||||
|
suite.CreateTestUser(
|
||||||
|
fmt.Sprintf("testuser_%d", i),
|
||||||
|
fmt.Sprintf("user%d@example.com", i),
|
||||||
|
"password123",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []database.User
|
||||||
|
result := query.Find(&users)
|
||||||
|
if result.Error != nil {
|
||||||
|
t.Fatalf("Query failed: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := 5
|
||||||
|
if len(users) != expected {
|
||||||
|
t.Errorf("Expected %d users with offset %d, got %d", expected, offset, len(users))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "offset = 0 (should not apply offset)",
|
||||||
|
limit: 10,
|
||||||
|
offset: 0,
|
||||||
|
setupQuery: func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Model(&database.User{})
|
||||||
|
},
|
||||||
|
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
||||||
|
for i := range 15 {
|
||||||
|
suite.CreateTestUser(
|
||||||
|
fmt.Sprintf("testuser_%d", i),
|
||||||
|
fmt.Sprintf("user%d@example.com", i),
|
||||||
|
"password123",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []database.User
|
||||||
|
result := query.Find(&users)
|
||||||
|
if result.Error != nil {
|
||||||
|
t.Fatalf("Query failed: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(users) != limit {
|
||||||
|
t.Errorf("Expected %d users, got %d", limit, len(users))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "limit = 0 and offset = 0 (should not apply pagination)",
|
||||||
|
limit: 0,
|
||||||
|
offset: 0,
|
||||||
|
setupQuery: func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Model(&database.User{})
|
||||||
|
},
|
||||||
|
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
||||||
|
for i := range 10 {
|
||||||
|
suite.CreateTestUser(
|
||||||
|
fmt.Sprintf("testuser_%d", i),
|
||||||
|
fmt.Sprintf("user%d@example.com", i),
|
||||||
|
"password123",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []database.User
|
||||||
|
result := query.Find(&users)
|
||||||
|
if result.Error != nil {
|
||||||
|
t.Fatalf("Query failed: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(users) != 10 {
|
||||||
|
t.Errorf("Expected all 10 users, got %d", len(users))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative limit (should not apply limit)",
|
||||||
|
limit: -5,
|
||||||
|
offset: 10,
|
||||||
|
setupQuery: func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Model(&database.User{})
|
||||||
|
},
|
||||||
|
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
||||||
|
for i := range 20 {
|
||||||
|
suite.CreateTestUser(
|
||||||
|
fmt.Sprintf("testuser_%d", i),
|
||||||
|
fmt.Sprintf("user%d@example.com", i),
|
||||||
|
"password123",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []database.User
|
||||||
|
result := query.Find(&users)
|
||||||
|
if result.Error != nil {
|
||||||
|
t.Fatalf("Query failed: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := 10
|
||||||
|
if len(users) != expected {
|
||||||
|
t.Errorf("Expected %d users with offset %d, got %d", expected, offset, len(users))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative offset (should not apply offset)",
|
||||||
|
limit: 10,
|
||||||
|
offset: -5,
|
||||||
|
setupQuery: func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Model(&database.User{})
|
||||||
|
},
|
||||||
|
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
||||||
|
for i := range 15 {
|
||||||
|
suite.CreateTestUser(
|
||||||
|
fmt.Sprintf("testuser_%d", i),
|
||||||
|
fmt.Sprintf("user%d@example.com", i),
|
||||||
|
"password123",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []database.User
|
||||||
|
result := query.Find(&users)
|
||||||
|
if result.Error != nil {
|
||||||
|
t.Fatalf("Query failed: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(users) != limit {
|
||||||
|
t.Errorf("Expected %d users, got %d", limit, len(users))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large limit and offset values",
|
||||||
|
limit: 1000,
|
||||||
|
offset: 500,
|
||||||
|
setupQuery: func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Model(&database.User{})
|
||||||
|
},
|
||||||
|
verifyPagination: func(t *testing.T, query *gorm.DB, limit, offset int) {
|
||||||
|
for i := range 2000 {
|
||||||
|
suite.CreateTestUser(
|
||||||
|
fmt.Sprintf("testuser_%d", i),
|
||||||
|
fmt.Sprintf("user%d@example.com", i),
|
||||||
|
"password123",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []database.User
|
||||||
|
result := query.Find(&users)
|
||||||
|
if result.Error != nil {
|
||||||
|
t.Fatalf("Query failed: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(users) != limit {
|
||||||
|
t.Errorf("Expected %d users, got %d", limit, len(users))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
suite.Reset()
|
||||||
|
|
||||||
|
baseQuery := tt.setupQuery(suite.DB)
|
||||||
|
paginatedQuery := ApplyPagination(baseQuery, tt.limit, tt.offset)
|
||||||
|
|
||||||
|
tt.verifyPagination(t, paginatedQuery, tt.limit, tt.offset)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -123,46 +123,27 @@ func (j *JWTService) VerifyAccessToken(tokenString string) (uint, error) {
|
|||||||
return claims.UserID, nil
|
return claims.UserID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *JWTService) RefreshAccessToken(refreshTokenString string) (string, error) {
|
func (j *JWTService) RefreshAccessTokenWithRotation(refreshTokenString string) (string, string, error) {
|
||||||
|
refreshToken, user, err := j.validateRefreshToken(refreshTokenString)
|
||||||
tokenHash := j.hashToken(refreshTokenString)
|
|
||||||
|
|
||||||
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
return "", "", err
|
||||||
return "", ErrRefreshTokenInvalid
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("lookup refresh token: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().After(refreshToken.ExpiresAt) {
|
if err := j.refreshRepo.DeleteByID(refreshToken.ID); err != nil {
|
||||||
|
return "", "", fmt.Errorf("revoke refresh token: %w", err)
|
||||||
j.refreshRepo.DeleteByID(refreshToken.ID)
|
|
||||||
return "", ErrRefreshTokenExpired
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := j.userRepo.GetByID(refreshToken.UserID)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
|
|
||||||
j.refreshRepo.DeleteByID(refreshToken.ID)
|
|
||||||
return "", ErrRefreshTokenInvalid
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("lookup user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Locked {
|
|
||||||
|
|
||||||
j.refreshRepo.DeleteByID(refreshToken.ID)
|
|
||||||
return "", ErrAccountLocked
|
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err := j.GenerateAccessToken(user)
|
accessToken, err := j.GenerateAccessToken(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("generate access token: %w", err)
|
return "", "", fmt.Errorf("generate access token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return accessToken, nil
|
newRefreshToken, err := j.GenerateRefreshToken(user)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("generate refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return accessToken, newRefreshToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *JWTService) RevokeRefreshToken(refreshTokenString string) error {
|
func (j *JWTService) RevokeRefreshToken(refreshTokenString string) error {
|
||||||
@@ -354,6 +335,39 @@ func (j *JWTService) validateTokenMetadata(token *jwt.Token, claims *TokenClaims
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (j *JWTService) validateRefreshToken(refreshTokenString string) (*database.RefreshToken, *database.User, error) {
|
||||||
|
tokenHash := j.hashToken(refreshTokenString)
|
||||||
|
|
||||||
|
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil, ErrRefreshTokenInvalid
|
||||||
|
}
|
||||||
|
return nil, nil, fmt.Errorf("lookup refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(refreshToken.ExpiresAt) {
|
||||||
|
_ = j.refreshRepo.DeleteByID(refreshToken.ID)
|
||||||
|
return nil, nil, ErrRefreshTokenExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := j.userRepo.GetByID(refreshToken.UserID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
_ = j.refreshRepo.DeleteByID(refreshToken.ID)
|
||||||
|
return nil, nil, ErrRefreshTokenInvalid
|
||||||
|
}
|
||||||
|
return nil, nil, fmt.Errorf("lookup user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Locked {
|
||||||
|
_ = j.refreshRepo.DeleteByID(refreshToken.ID)
|
||||||
|
return nil, nil, ErrAccountLocked
|
||||||
|
}
|
||||||
|
|
||||||
|
return refreshToken, user, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (j *JWTService) hashToken(token string) string {
|
func (j *JWTService) hashToken(token string) string {
|
||||||
hash := sha256.Sum256([]byte(token))
|
hash := sha256.Sum256([]byte(token))
|
||||||
return hex.EncodeToString(hash[:])
|
return hex.EncodeToString(hash[:])
|
||||||
|
|||||||
@@ -348,19 +348,17 @@ func TestJWTService_VerifyAccessToken(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJWTService_RefreshAccessToken(t *testing.T) {
|
func TestJWTService_RefreshAccessTokenWithRotation(t *testing.T) {
|
||||||
jwtService, userRepo, refreshRepo := createTestJWTService()
|
jwtService, userRepo, refreshRepo := createTestJWTService()
|
||||||
user := createTestUser()
|
user := createTestUser()
|
||||||
userRepo.users[user.ID] = user
|
userRepo.users[user.ID] = user
|
||||||
|
|
||||||
t.Run("Successful_Refresh", func(t *testing.T) {
|
|
||||||
|
|
||||||
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to generate refresh token: %v", err)
|
t.Fatalf("Failed to generate refresh token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err := jwtService.RefreshAccessToken(refreshToken)
|
accessToken, newRefreshToken, err := jwtService.RefreshAccessTokenWithRotation(refreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Expected successful token refresh, got error: %v", err)
|
t.Fatalf("Expected successful token refresh, got error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -368,19 +366,36 @@ func TestJWTService_RefreshAccessToken(t *testing.T) {
|
|||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
t.Error("Expected non-empty access token")
|
t.Error("Expected non-empty access token")
|
||||||
}
|
}
|
||||||
|
if newRefreshToken == "" {
|
||||||
|
t.Error("Expected non-empty refresh token")
|
||||||
|
}
|
||||||
|
if newRefreshToken == refreshToken {
|
||||||
|
t.Error("Expected refresh token to rotate")
|
||||||
|
}
|
||||||
|
|
||||||
userID, err := jwtService.VerifyAccessToken(accessToken)
|
userID, err := jwtService.VerifyAccessToken(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Expected valid access token, got error: %v", err)
|
t.Fatalf("Expected valid access token, got error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if userID != user.ID {
|
if userID != user.ID {
|
||||||
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
||||||
}
|
}
|
||||||
})
|
|
||||||
|
_, _, err = jwtService.RefreshAccessTokenWithRotation(refreshToken)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for rotated refresh token")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrRefreshTokenInvalid) {
|
||||||
|
t.Errorf("Expected ErrRefreshTokenInvalid, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = jwtService.RefreshAccessTokenWithRotation(newRefreshToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected new refresh token to be usable, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("Invalid_Refresh_Token", func(t *testing.T) {
|
t.Run("Invalid_Refresh_Token", func(t *testing.T) {
|
||||||
_, err := jwtService.RefreshAccessToken("invalid-refresh-token")
|
_, _, err := jwtService.RefreshAccessTokenWithRotation("invalid-refresh-token")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error for invalid refresh token")
|
t.Error("Expected error for invalid refresh token")
|
||||||
}
|
}
|
||||||
@@ -390,7 +405,6 @@ func TestJWTService_RefreshAccessToken(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Expired_Refresh_Token", func(t *testing.T) {
|
t.Run("Expired_Refresh_Token", func(t *testing.T) {
|
||||||
|
|
||||||
refreshToken := &database.RefreshToken{
|
refreshToken := &database.RefreshToken{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
TokenHash: "expired-token-hash",
|
TokenHash: "expired-token-hash",
|
||||||
@@ -403,7 +417,7 @@ func TestJWTService_RefreshAccessToken(t *testing.T) {
|
|||||||
refreshToken.TokenHash = tokenHash
|
refreshToken.TokenHash = tokenHash
|
||||||
refreshRepo.tokens[tokenHash] = refreshToken
|
refreshRepo.tokens[tokenHash] = refreshToken
|
||||||
|
|
||||||
_, err := jwtService.RefreshAccessToken(testToken)
|
_, _, err := jwtService.RefreshAccessTokenWithRotation(testToken)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error for expired refresh token")
|
t.Error("Expected error for expired refresh token")
|
||||||
}
|
}
|
||||||
@@ -413,7 +427,6 @@ func TestJWTService_RefreshAccessToken(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("User_Not_Found", func(t *testing.T) {
|
t.Run("User_Not_Found", func(t *testing.T) {
|
||||||
|
|
||||||
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to generate refresh token: %v", err)
|
t.Fatalf("Failed to generate refresh token: %v", err)
|
||||||
@@ -421,7 +434,7 @@ func TestJWTService_RefreshAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
delete(userRepo.users, user.ID)
|
delete(userRepo.users, user.ID)
|
||||||
|
|
||||||
_, err = jwtService.RefreshAccessToken(refreshToken)
|
_, _, err = jwtService.RefreshAccessTokenWithRotation(refreshToken)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error for non-existent user")
|
t.Error("Expected error for non-existent user")
|
||||||
}
|
}
|
||||||
@@ -439,7 +452,7 @@ func TestJWTService_RefreshAccessToken(t *testing.T) {
|
|||||||
t.Fatalf("Failed to generate refresh token: %v", err)
|
t.Fatalf("Failed to generate refresh token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = jwtService.RefreshAccessToken(refreshToken)
|
_, _, err = jwtService.RefreshAccessTokenWithRotation(refreshToken)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error for locked user")
|
t.Error("Expected error for locked user")
|
||||||
}
|
}
|
||||||
@@ -937,7 +950,7 @@ func TestJWTService_Integration(t *testing.T) {
|
|||||||
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
newAccessToken, err := jwtService.RefreshAccessToken(refreshToken)
|
newAccessToken, rotatedRefreshToken, err := jwtService.RefreshAccessTokenWithRotation(refreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to refresh access token: %v", err)
|
t.Fatalf("Failed to refresh access token: %v", err)
|
||||||
}
|
}
|
||||||
@@ -950,12 +963,12 @@ func TestJWTService_Integration(t *testing.T) {
|
|||||||
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = jwtService.RevokeRefreshToken(refreshToken)
|
err = jwtService.RevokeRefreshToken(rotatedRefreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to revoke refresh token: %v", err)
|
t.Fatalf("Failed to revoke refresh token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = jwtService.RefreshAccessToken(refreshToken)
|
_, _, err = jwtService.RefreshAccessTokenWithRotation(rotatedRefreshToken)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error when using revoked refresh token")
|
t.Error("Expected error when using revoked refresh token")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func (s *SessionService) issueAuthResult(user *database.User) (*AuthResult, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SessionService) RefreshAccessToken(refreshToken string) (*AuthResult, error) {
|
func (s *SessionService) RefreshAccessToken(refreshToken string) (*AuthResult, error) {
|
||||||
accessToken, err := s.jwtService.RefreshAccessToken(refreshToken)
|
accessToken, newRefreshToken, err := s.jwtService.RefreshAccessTokenWithRotation(refreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -88,7 +88,7 @@ func (s *SessionService) RefreshAccessToken(refreshToken string) (*AuthResult, e
|
|||||||
|
|
||||||
return &AuthResult{
|
return &AuthResult{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: newRefreshToken,
|
||||||
User: sanitizeUser(user),
|
User: sanitizeUser(user),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -398,12 +398,23 @@ func TestSessionService_RefreshAccessToken(t *testing.T) {
|
|||||||
if result.AccessToken == "" {
|
if result.AccessToken == "" {
|
||||||
t.Error("expected non-empty access token")
|
t.Error("expected non-empty access token")
|
||||||
}
|
}
|
||||||
if result.RefreshToken != refreshToken {
|
if result.RefreshToken == "" {
|
||||||
t.Errorf("expected refresh token to remain unchanged")
|
t.Error("expected non-empty refresh token")
|
||||||
|
}
|
||||||
|
if result.RefreshToken == refreshToken {
|
||||||
|
t.Errorf("expected refresh token to rotate")
|
||||||
}
|
}
|
||||||
if result.User == nil {
|
if result.User == nil {
|
||||||
t.Fatal("expected non-nil user")
|
t.Fatal("expected non-nil user")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = service.RefreshAccessToken(refreshToken)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when using rotated refresh token")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrRefreshTokenInvalid) {
|
||||||
|
t.Errorf("expected ErrRefreshTokenInvalid, got %v", err)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid refresh token", func(t *testing.T) {
|
t.Run("invalid refresh token", func(t *testing.T) {
|
||||||
|
|||||||
@@ -3,51 +3,61 @@ package validation
|
|||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"unicode/utf8"
|
||||||
"goyco/internal/fuzz"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func FuzzValidateEmail(f *testing.F) {
|
func FuzzValidateEmail(f *testing.F) {
|
||||||
helper := fuzz.NewFuzzTestHelper()
|
runValidationFuzzTest(f, ValidateEmail)
|
||||||
helper.RunValidationFuzzTest(f, ValidateEmail)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func FuzzValidateUsername(f *testing.F) {
|
func FuzzValidateUsername(f *testing.F) {
|
||||||
helper := fuzz.NewFuzzTestHelper()
|
runValidationFuzzTest(f, ValidateUsername)
|
||||||
helper.RunValidationFuzzTest(f, ValidateUsername)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func FuzzValidatePassword(f *testing.F) {
|
func FuzzValidatePassword(f *testing.F) {
|
||||||
helper := fuzz.NewFuzzTestHelper()
|
runValidationFuzzTest(f, ValidatePassword)
|
||||||
helper.RunValidationFuzzTest(f, ValidatePassword)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func FuzzValidateURL(f *testing.F) {
|
func FuzzValidateURL(f *testing.F) {
|
||||||
helper := fuzz.NewFuzzTestHelper()
|
runValidationFuzzTest(f, ValidateURL)
|
||||||
helper.RunValidationFuzzTest(f, ValidateURL)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func FuzzValidateTitle(f *testing.F) {
|
func FuzzValidateTitle(f *testing.F) {
|
||||||
helper := fuzz.NewFuzzTestHelper()
|
runValidationFuzzTest(f, ValidateTitle)
|
||||||
helper.RunValidationFuzzTest(f, ValidateTitle)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func FuzzValidateContent(f *testing.F) {
|
func FuzzValidateContent(f *testing.F) {
|
||||||
helper := fuzz.NewFuzzTestHelper()
|
runValidationFuzzTest(f, ValidateContent)
|
||||||
helper.RunValidationFuzzTest(f, ValidateContent)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func FuzzValidateSearchQuery(f *testing.F) {
|
func FuzzValidateSearchQuery(f *testing.F) {
|
||||||
helper := fuzz.NewFuzzTestHelper()
|
runValidationFuzzTest(f, ValidateSearchQuery)
|
||||||
helper.RunValidationFuzzTest(f, ValidateSearchQuery)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func FuzzSanitizeString(f *testing.F) {
|
func FuzzSanitizeString(f *testing.F) {
|
||||||
helper := fuzz.NewFuzzTestHelper()
|
f.Add("test input")
|
||||||
helper.RunSanitizationFuzzTestWithValidation(f,
|
f.Fuzz(func(t *testing.T, input string) {
|
||||||
SanitizeString,
|
if !utf8.ValidString(input) {
|
||||||
func(result string) bool {
|
return
|
||||||
return !containsNullBytes(result)
|
}
|
||||||
|
result := SanitizeString(input)
|
||||||
|
if !utf8.ValidString(result) {
|
||||||
|
t.Fatal("Sanitized result contains invalid UTF-8")
|
||||||
|
}
|
||||||
|
if containsNullBytes(result) {
|
||||||
|
t.Fatal("Sanitized result contains null bytes")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runValidationFuzzTest(f *testing.F, validateFunc func(string) error) {
|
||||||
|
f.Add("test input")
|
||||||
|
f.Fuzz(func(t *testing.T, input string) {
|
||||||
|
if !utf8.ValidString(input) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := validateFunc(input)
|
||||||
|
_ = err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -175,6 +175,33 @@ type FieldValidationError struct {
|
|||||||
Message string
|
Message string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getFieldDisplayName(field reflect.StructField) string {
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" && jsonTag != "-" {
|
||||||
|
if idx := strings.Index(jsonTag, ","); idx != -1 {
|
||||||
|
jsonTag = jsonTag[:idx]
|
||||||
|
}
|
||||||
|
if jsonTag != "" {
|
||||||
|
return jsonTag
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return camelCaseToWords(field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func camelCaseToWords(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
var result strings.Builder
|
||||||
|
for i, r := range s {
|
||||||
|
if i > 0 && unicode.IsUpper(r) {
|
||||||
|
result.WriteRune(' ')
|
||||||
|
}
|
||||||
|
result.WriteRune(unicode.ToLower(r))
|
||||||
|
}
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
|
|
||||||
func ValidateStruct(s interface{}) error {
|
func ValidateStruct(s interface{}) error {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -232,9 +259,10 @@ func ValidateStruct(s interface{}) error {
|
|||||||
tagName = tag
|
tagName = tag
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateField(field.Name, fieldVal, tagName, param, omitempty); err != nil {
|
displayName := getFieldDisplayName(field)
|
||||||
|
if err := validateField(displayName, fieldVal, tagName, param, omitempty); err != nil {
|
||||||
errors = append(errors, FieldValidationError{
|
errors = append(errors, FieldValidationError{
|
||||||
Field: field.Name,
|
Field: displayName,
|
||||||
Tag: tagName,
|
Tag: tagName,
|
||||||
Param: param,
|
Param: param,
|
||||||
Message: err.Message,
|
Message: err.Message,
|
||||||
@@ -293,7 +321,7 @@ func validateField(fieldName string, fieldVal reflect.Value, tagName, param stri
|
|||||||
func isEmptyValue(v reflect.Value) bool {
|
func isEmptyValue(v reflect.Value) bool {
|
||||||
switch v.Kind() {
|
switch v.Kind() {
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
return v.String() == ""
|
return strings.TrimSpace(v.String()) == ""
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
return v.Int() == 0
|
return v.Int() == 0
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
@@ -301,7 +329,7 @@ func isEmptyValue(v reflect.Value) bool {
|
|||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return v.Float() == 0
|
return v.Float() == 0
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return !v.Bool()
|
return false
|
||||||
case reflect.Pointer, reflect.Interface, reflect.Slice, reflect.Map:
|
case reflect.Pointer, reflect.Interface, reflect.Slice, reflect.Map:
|
||||||
return v.IsNil()
|
return v.IsNil()
|
||||||
default:
|
default:
|
||||||
@@ -317,7 +345,8 @@ func validateMin(fieldName string, v reflect.Value, param string) *ValidationErr
|
|||||||
|
|
||||||
switch v.Kind() {
|
switch v.Kind() {
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
if len(v.String()) < min {
|
s := strings.TrimSpace(v.String())
|
||||||
|
if len([]rune(s)) < min {
|
||||||
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at least %d characters", fieldName, min)}
|
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at least %d characters", fieldName, min)}
|
||||||
}
|
}
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
@@ -325,6 +354,9 @@ func validateMin(fieldName string, v reflect.Value, param string) *ValidationErr
|
|||||||
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at least %d", fieldName, min)}
|
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at least %d", fieldName, min)}
|
||||||
}
|
}
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
if min < 0 {
|
||||||
|
return &ValidationError{Field: fieldName, Message: fieldName + " has invalid min parameter (must be non-negative)"}
|
||||||
|
}
|
||||||
if v.Uint() < uint64(min) {
|
if v.Uint() < uint64(min) {
|
||||||
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at least %d", fieldName, min)}
|
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at least %d", fieldName, min)}
|
||||||
}
|
}
|
||||||
@@ -341,7 +373,7 @@ func validateMax(fieldName string, v reflect.Value, param string) *ValidationErr
|
|||||||
|
|
||||||
switch v.Kind() {
|
switch v.Kind() {
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
if len(v.String()) > max {
|
if len([]rune(v.String())) > max {
|
||||||
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at most %d characters", fieldName, max)}
|
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at most %d characters", fieldName, max)}
|
||||||
}
|
}
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
@@ -349,6 +381,9 @@ func validateMax(fieldName string, v reflect.Value, param string) *ValidationErr
|
|||||||
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at most %d", fieldName, max)}
|
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at most %d", fieldName, max)}
|
||||||
}
|
}
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
if max < 0 {
|
||||||
|
return &ValidationError{Field: fieldName, Message: fieldName + " has invalid max parameter (must be non-negative)"}
|
||||||
|
}
|
||||||
if v.Uint() > uint64(max) {
|
if v.Uint() > uint64(max) {
|
||||||
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at most %d", fieldName, max)}
|
return &ValidationError{Field: fieldName, Message: fmt.Sprintf("%s must be at most %d", fieldName, max)}
|
||||||
}
|
}
|
||||||
@@ -402,12 +437,15 @@ func validateOneOf(fieldName string, v reflect.Value, param string) *ValidationE
|
|||||||
|
|
||||||
value := v.String()
|
value := v.String()
|
||||||
allowedValues := strings.Split(param, " ")
|
allowedValues := strings.Split(param, " ")
|
||||||
|
allowedMap := make(map[string]bool, len(allowedValues))
|
||||||
|
|
||||||
for _, allowed := range allowedValues {
|
for _, allowed := range allowedValues {
|
||||||
if value == allowed {
|
allowedMap[allowed] = true
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !allowedMap[value] {
|
||||||
return &ValidationError{Field: fieldName, Message: fieldName + " must be one of: " + param}
|
return &ValidationError{Field: fieldName, Message: fieldName + " must be one of: " + param}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -250,12 +250,12 @@ func TestSanitizeString(t *testing.T) {
|
|||||||
|
|
||||||
func TestValidateStruct(t *testing.T) {
|
func TestValidateStruct(t *testing.T) {
|
||||||
type TestStruct struct {
|
type TestStruct struct {
|
||||||
Username string `validate:"required,min=3,max=20"`
|
Username string `json:"username" validate:"required,min=3,max=20"`
|
||||||
Email string `validate:"required,email"`
|
Email string `json:"email" validate:"required,email"`
|
||||||
Age int `validate:"min=18,max=120"`
|
Age int `json:"age" validate:"min=18,max=120"`
|
||||||
URL string `validate:"url"`
|
URL string `json:"url" validate:"url"`
|
||||||
Status string `validate:"oneof=active inactive pending"`
|
Status string `json:"status" validate:"oneof=active inactive pending"`
|
||||||
Optional string `validate:"omitempty,min=1"`
|
Optional string `json:"optional" validate:"omitempty,min=1"`
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("valid struct", func(t *testing.T) {
|
t.Run("valid struct", func(t *testing.T) {
|
||||||
@@ -287,6 +287,9 @@ func TestValidateStruct(t *testing.T) {
|
|||||||
if len(structErr.Errors) == 0 {
|
if len(structErr.Errors) == 0 {
|
||||||
t.Error("Expected validation errors, got none")
|
t.Error("Expected validation errors, got none")
|
||||||
}
|
}
|
||||||
|
if structErr.Errors[0].Message != "username is required" {
|
||||||
|
t.Errorf("Expected JSON tag name in error, got %q", structErr.Errors[0].Message)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -318,6 +321,20 @@ func TestValidateStruct(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("whitespace required field", func(t *testing.T) {
|
||||||
|
s := TestStruct{
|
||||||
|
Username: " ",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Age: 25,
|
||||||
|
URL: "https://example.com",
|
||||||
|
Status: "active",
|
||||||
|
}
|
||||||
|
err := ValidateStruct(s)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("ValidateStruct() expected error, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("invalid max", func(t *testing.T) {
|
t.Run("invalid max", func(t *testing.T) {
|
||||||
s := TestStruct{
|
s := TestStruct{
|
||||||
Username: strings.Repeat("a", 21),
|
Username: strings.Repeat("a", 21),
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
package version
|
package version
|
||||||
|
|
||||||
const Version = "0.1.0"
|
const version = "0.1.1"
|
||||||
|
|
||||||
|
func GetVersion() string {
|
||||||
|
return version
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,39 +8,39 @@ import (
|
|||||||
func TestVersionSemver(t *testing.T) {
|
func TestVersionSemver(t *testing.T) {
|
||||||
semverRegex := regexp.MustCompile(`^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$`)
|
semverRegex := regexp.MustCompile(`^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$`)
|
||||||
|
|
||||||
if !semverRegex.MatchString(Version) {
|
if !semverRegex.MatchString(GetVersion()) {
|
||||||
t.Errorf("Version %q does not follow semantic versioning format (MAJOR.MINOR.PATCH[-PRERELEASE][+BUILD])", Version)
|
t.Errorf("Version %q does not follow semantic versioning format (MAJOR.MINOR.PATCH[-PRERELEASE][+BUILD])", GetVersion())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVersionSemverFlexible(t *testing.T) {
|
func TestVersionSemverFlexible(t *testing.T) {
|
||||||
flexibleSemverRegex := regexp.MustCompile(`^(0|[1-9]\d*)\.(0|[1-9]\d*)(?:\.(0|[1-9]\d*))?(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$`)
|
flexibleSemverRegex := regexp.MustCompile(`^(0|[1-9]\d*)\.(0|[1-9]\d*)(?:\.(0|[1-9]\d*))?(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$`)
|
||||||
|
|
||||||
if !flexibleSemverRegex.MatchString(Version) {
|
if !flexibleSemverRegex.MatchString(GetVersion()) {
|
||||||
t.Errorf("Version %q does not follow semantic versioning format (MAJOR.MINOR[.PATCH][-PRERELEASE][+BUILD])", Version)
|
t.Errorf("Version %q does not follow semantic versioning format (MAJOR.MINOR[.PATCH][-PRERELEASE][+BUILD])", GetVersion())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVersionNotEmpty(t *testing.T) {
|
func TestVersionNotEmpty(t *testing.T) {
|
||||||
if Version == "" {
|
if GetVersion() == "" {
|
||||||
t.Error("Version should not be empty")
|
t.Error("Version should not be empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVersionFormat(t *testing.T) {
|
func TestVersionFormat(t *testing.T) {
|
||||||
if !regexp.MustCompile(`\d+\.\d+`).MatchString(Version) {
|
if !regexp.MustCompile(`\d+\.\d+`).MatchString(GetVersion()) {
|
||||||
t.Errorf("Version %q should contain at least MAJOR.MINOR format", Version)
|
t.Errorf("Version %q should contain at least MAJOR.MINOR format", GetVersion())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVersionStartsWithNumber(t *testing.T) {
|
func TestVersionStartsWithNumber(t *testing.T) {
|
||||||
if !regexp.MustCompile(`^\d+`).MatchString(Version) {
|
if !regexp.MustCompile(`^\d+`).MatchString(GetVersion()) {
|
||||||
t.Errorf("Version %q should start with a number", Version)
|
t.Errorf("Version %q should start with a number", GetVersion())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVersionNoLeadingZeros(t *testing.T) {
|
func TestVersionNoLeadingZeros(t *testing.T) {
|
||||||
parts := regexp.MustCompile(`^(\d+)\.(\d+)`).FindStringSubmatch(Version)
|
parts := regexp.MustCompile(`^(\d+)\.(\d+)`).FindStringSubmatch(GetVersion())
|
||||||
if len(parts) >= 3 {
|
if len(parts) >= 3 {
|
||||||
major := parts[1]
|
major := parts[1]
|
||||||
minor := parts[2]
|
minor := parts[2]
|
||||||
|
|||||||
Reference in New Issue
Block a user