Compare commits
16 Commits
457b5c88e2
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 07c6b89525 | |||
| 817205d42f | |||
| 199ac143a4 | |||
| aa7e259ed0 | |||
| 4587609e17 | |||
| 33da6503e3 | |||
| cafc44ed77 | |||
| 1480135e75 | |||
| 02a764c736 | |||
| 6834ad7764 | |||
| dcf054046f | |||
| d2a788933d | |||
| 18be3950dc | |||
| f9cb140e95 | |||
| 86d4835ccf | |||
| feddb2ed43 |
38
README.md
38
README.md
@@ -354,44 +354,6 @@ It will start the application in development mode. You can also run it as a daem
|
|||||||
|
|
||||||
Then, use `./bin/goyco` to manage the application and notably to seed the database with sample data.
|
Then, use `./bin/goyco` to manage the application and notably to seed the database with sample data.
|
||||||
|
|
||||||
### Project Structure
|
|
||||||
|
|
||||||
```bash
|
|
||||||
goyco/
|
|
||||||
├── bin/ # Compiled binaries (created after build)
|
|
||||||
├── cmd/
|
|
||||||
│ └── goyco/ # Main CLI application entrypoint
|
|
||||||
├── docker/ # Docker Compose & related files
|
|
||||||
├── docs/ # Documentation and API specs
|
|
||||||
├── internal/
|
|
||||||
│ ├── config/ # Configuration management
|
|
||||||
│ ├── database/ # Database models and access
|
|
||||||
│ ├── dto/ # Data Transfer Objects (DTOs)
|
|
||||||
│ ├── e2e/ # End-to-end tests
|
|
||||||
│ ├── fuzz/ # Fuzz tests
|
|
||||||
│ ├── handlers/ # HTTP handlers
|
|
||||||
│ ├── integration/ # Integration tests
|
|
||||||
│ ├── middleware/ # HTTP middleware
|
|
||||||
│ ├── repositories/ # Data access layer
|
|
||||||
│ ├── security/ # Security and auth logic
|
|
||||||
│ ├── server/ # HTTP server implementation
|
|
||||||
│ ├── services/ # Business logic
|
|
||||||
│ ├── static/ # Static web assets
|
|
||||||
│ ├── templates/ # HTML templates
|
|
||||||
│ ├── testutils/ # Test helpers/utilities
|
|
||||||
│ ├── validation/ # Input validation
|
|
||||||
│ └── version/ # Version information
|
|
||||||
├── scripts/ # Utility/maintenance scripts
|
|
||||||
├── services/
|
|
||||||
│ └── goyco.service # Systemd service unit example
|
|
||||||
├── .env.example # Environment variable example
|
|
||||||
├── AUTHORS # Authors file
|
|
||||||
├── Dockerfile # Docker build file
|
|
||||||
├── LICENSE # License file
|
|
||||||
├── Makefile # Project build/test targets
|
|
||||||
└── README.md # This file
|
|
||||||
```
|
|
||||||
|
|
||||||
### Testing
|
### Testing
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"goyco/internal/config"
|
"goyco/internal/config"
|
||||||
"goyco/internal/database"
|
"goyco/internal/database"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrHelpRequested = errors.New("help requested")
|
var ErrHelpRequested = errors.New("help requested")
|
||||||
@@ -118,26 +119,6 @@ func outputJSON(v interface{}) error {
|
|||||||
return encoder.Encode(v)
|
return encoder.Encode(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func outputMessage(message string, args ...interface{}) {
|
|
||||||
if IsJSONOutput() {
|
|
||||||
outputJSON(map[string]interface{}{
|
|
||||||
"message": fmt.Sprintf(message, args...),
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
fmt.Printf(message+"\n", args...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func outputError(message string, args ...interface{}) {
|
|
||||||
if IsJSONOutput() {
|
|
||||||
outputJSON(map[string]interface{}{
|
|
||||||
"error": fmt.Sprintf(message, args...),
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(os.Stderr, message+"\n", args...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func outputWarning(message string, args ...interface{}) {
|
func outputWarning(message string, args ...interface{}) {
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]interface{}{
|
||||||
|
|||||||
@@ -5,9 +5,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"goyco/internal/config"
|
"goyco/internal/config"
|
||||||
"goyco/internal/database"
|
"goyco/internal/database"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func HandleMigrateCommand(cfg *config.Config, name string, args []string) error {
|
func HandleMigrateCommand(cfg *config.Config, name string, args []string) error {
|
||||||
@@ -37,7 +38,7 @@ func runMigrateCommand(db *gorm.DB) error {
|
|||||||
return fmt.Errorf("run migrations: %w", err)
|
return fmt.Errorf("run migrations: %w", err)
|
||||||
}
|
}
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "migrations_applied",
|
"action": "migrations_applied",
|
||||||
"status": "success",
|
"status": "success",
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -261,6 +261,7 @@ func processItemsInParallelNoResult[T any](
|
|||||||
) error {
|
) error {
|
||||||
count := len(items)
|
count := len(items)
|
||||||
errors := make(chan error, count)
|
errors := make(chan error, count)
|
||||||
|
completions := make(chan struct{}, count)
|
||||||
|
|
||||||
semaphore := make(chan struct{}, maxWorkers)
|
semaphore := make(chan struct{}, maxWorkers)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
@@ -288,20 +289,45 @@ func processItemsInParallelNoResult[T any](
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if progress != nil {
|
completions <- struct{}{}
|
||||||
progress.Update(index + 1)
|
|
||||||
}
|
|
||||||
}(i, item)
|
}(i, item)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
close(errors)
|
close(errors)
|
||||||
|
close(completions)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for err := range errors {
|
completed := 0
|
||||||
if err != nil {
|
firstError := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for err := range errors {
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case firstError <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for completed < count {
|
||||||
|
select {
|
||||||
|
case _, ok := <-completions:
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
completed++
|
||||||
|
if progress != nil {
|
||||||
|
progress.Update(completed)
|
||||||
|
}
|
||||||
|
case err := <-firstError:
|
||||||
return err
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("timeout: %w", ctx.Err())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,15 +2,15 @@ package commands_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
"goyco/cmd/goyco/commands"
|
"goyco/cmd/goyco/commands"
|
||||||
"goyco/internal/database"
|
"goyco/internal/database"
|
||||||
"goyco/internal/repositories"
|
"goyco/internal/repositories"
|
||||||
"goyco/internal/testutils"
|
"goyco/internal/testutils"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
|
func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
|
||||||
@@ -25,7 +25,7 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
|
|||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "creates users with deterministic fields",
|
name: "creates users with required fields",
|
||||||
count: successCount,
|
count: successCount,
|
||||||
repoFactory: func() repositories.UserRepository {
|
repoFactory: func() repositories.UserRepository {
|
||||||
base := testutils.NewMockUserRepository()
|
base := testutils.NewMockUserRepository()
|
||||||
@@ -37,14 +37,24 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
|
|||||||
if len(got) != successCount {
|
if len(got) != successCount {
|
||||||
t.Fatalf("expected %d users, got %d", successCount, len(got))
|
t.Fatalf("expected %d users, got %d", successCount, len(got))
|
||||||
}
|
}
|
||||||
|
usernames := make(map[string]bool)
|
||||||
for i, user := range got {
|
for i, user := range got {
|
||||||
expectedUsername := fmt.Sprintf("user_%d", i+1)
|
if user.Username == "" {
|
||||||
expectedEmail := fmt.Sprintf("user_%d@goyco.local", i+1)
|
t.Errorf("user %d expected non-empty username", i)
|
||||||
if user.Username != expectedUsername {
|
|
||||||
t.Errorf("user %d username mismatch: got %q want %q", i, user.Username, expectedUsername)
|
|
||||||
}
|
}
|
||||||
if user.Email != expectedEmail {
|
if len(user.Username) < 6 || user.Username[:5] != "user_" {
|
||||||
t.Errorf("user %d email mismatch: got %q want %q", i, user.Email, expectedEmail)
|
t.Errorf("user %d username should start with 'user_', got %q", i, user.Username)
|
||||||
|
}
|
||||||
|
if usernames[user.Username] {
|
||||||
|
t.Errorf("user %d duplicate username: %q", i, user.Username)
|
||||||
|
}
|
||||||
|
usernames[user.Username] = true
|
||||||
|
|
||||||
|
if user.Email == "" {
|
||||||
|
t.Errorf("user %d expected non-empty email", i)
|
||||||
|
}
|
||||||
|
if len(user.Email) < 20 || user.Email[:5] != "user_" || user.Email[len(user.Email)-12:] != "@goyco.local" {
|
||||||
|
t.Errorf("user %d email should match pattern 'user_*@goyco.local', got %q", i, user.Email)
|
||||||
}
|
}
|
||||||
if !user.EmailVerified {
|
if !user.EmailVerified {
|
||||||
t.Errorf("user %d expected EmailVerified to be true", i)
|
t.Errorf("user %d expected EmailVerified to be true", i)
|
||||||
@@ -83,6 +93,11 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
repo := tt.repoFactory()
|
repo := tt.repoFactory()
|
||||||
p := commands.NewParallelProcessor()
|
p := commands.NewParallelProcessor()
|
||||||
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate password hash: %v", err)
|
||||||
|
}
|
||||||
|
p.SetPasswordHash(string(passwordHash))
|
||||||
got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress)
|
got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress)
|
||||||
if gotErr != nil {
|
if gotErr != nil {
|
||||||
if !tt.wantErr {
|
if !tt.wantErr {
|
||||||
|
|||||||
@@ -35,17 +35,6 @@ func initSeedRand() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateRandomIdentifier() string {
|
|
||||||
initSeedRand()
|
|
||||||
const length = 12
|
|
||||||
const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
||||||
identifier := make([]byte, length)
|
|
||||||
for i := range identifier {
|
|
||||||
identifier[i] = chars[seedRandSource.Intn(len(chars))]
|
|
||||||
}
|
|
||||||
return string(identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
|
func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
|
||||||
fs := newFlagSet(name, printSeedUsage)
|
fs := newFlagSet(name, printSeedUsage)
|
||||||
if err := parseCommand(fs, args, name); err != nil {
|
if err := parseCommand(fs, args, name); err != nil {
|
||||||
@@ -236,44 +225,28 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func findExistingSeedUser(userRepo repositories.UserRepository) (*database.User, error) {
|
const (
|
||||||
user, err := userRepo.GetByUsernamePrefix("seed_admin_")
|
seedUsername = "seed_admin"
|
||||||
if err != nil {
|
seedEmail = "seed_admin@goyco.local"
|
||||||
return nil, fmt.Errorf("no existing seed user found")
|
)
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) {
|
func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) {
|
||||||
existingUser, err := findExistingSeedUser(userRepo)
|
if user, err := userRepo.GetByUsername(seedUsername); err == nil {
|
||||||
if err == nil && existingUser != nil {
|
|
||||||
return existingUser, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
randomID := generateRandomIdentifier()
|
|
||||||
seedUsername := fmt.Sprintf("seed_admin_%s", randomID)
|
|
||||||
seedEmail := fmt.Sprintf("seed_admin_%s@goyco.local", randomID)
|
|
||||||
|
|
||||||
const maxRetries = 10
|
|
||||||
for range maxRetries {
|
|
||||||
user := &database.User{
|
|
||||||
Username: seedUsername,
|
|
||||||
Email: seedEmail,
|
|
||||||
Password: passwordHash,
|
|
||||||
EmailVerified: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := userRepo.Create(user); err != nil {
|
|
||||||
randomID = generateRandomIdentifier()
|
|
||||||
seedUsername = fmt.Sprintf("seed_admin_%s", randomID)
|
|
||||||
seedEmail = fmt.Sprintf("seed_admin_%s@goyco.local", randomID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("failed to create seed user after %d attempts", maxRetries)
|
user := &database.User{
|
||||||
|
Username: seedUsername,
|
||||||
|
Email: seedEmail,
|
||||||
|
Password: passwordHash,
|
||||||
|
EmailVerified: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := userRepo.Create(user); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create seed user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
|
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ func TestSeedCommand(t *testing.T) {
|
|||||||
var seedUser *database.User
|
var seedUser *database.User
|
||||||
regularUserCount := 0
|
regularUserCount := 0
|
||||||
for i := range users {
|
for i := range users {
|
||||||
if strings.HasPrefix(users[i].Username, "seed_admin_") {
|
if users[i].Username == "seed_admin" {
|
||||||
seedUserCount++
|
seedUserCount++
|
||||||
seedUser = &users[i]
|
seedUser = &users[i]
|
||||||
} else if strings.HasPrefix(users[i].Username, "user_") {
|
} else if strings.HasPrefix(users[i].Username, "user_") {
|
||||||
@@ -63,12 +63,12 @@ func TestSeedCommand(t *testing.T) {
|
|||||||
t.Fatal("Expected seed user to be created")
|
t.Fatal("Expected seed user to be created")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(seedUser.Username, "seed_admin_") {
|
if seedUser.Username != "seed_admin" {
|
||||||
t.Errorf("Expected username to start with 'seed_admin_', got '%s'", seedUser.Username)
|
t.Errorf("Expected username to be 'seed_admin', got '%s'", seedUser.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(seedUser.Email, "seed_admin_") || !strings.HasSuffix(seedUser.Email, "@goyco.local") {
|
if seedUser.Email != "seed_admin@goyco.local" {
|
||||||
t.Errorf("Expected email to start with 'seed_admin_' and end with '@goyco.local', got '%s'", seedUser.Email)
|
t.Errorf("Expected email to be 'seed_admin@goyco.local', got '%s'", seedUser.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !seedUser.EmailVerified {
|
if !seedUser.EmailVerified {
|
||||||
@@ -302,13 +302,13 @@ func TestSeedCommandIdempotency(t *testing.T) {
|
|||||||
|
|
||||||
seedUserCount := 0
|
seedUserCount := 0
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
if strings.HasPrefix(user.Username, "seed_admin_") {
|
if user.Username == "seed_admin" {
|
||||||
seedUserCount++
|
seedUserCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if seedUserCount < 1 {
|
if seedUserCount != 1 {
|
||||||
t.Errorf("Expected at least 1 seed user, got %d", seedUserCount)
|
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -387,7 +387,7 @@ func TestSeedCommandIdempotency(t *testing.T) {
|
|||||||
|
|
||||||
func findSeedUser(users []database.User) *database.User {
|
func findSeedUser(users []database.User) *database.User {
|
||||||
for i := range users {
|
for i := range users {
|
||||||
if strings.HasPrefix(users[i].Username, "seed_admin_") {
|
if users[i].Username == "seed_admin" {
|
||||||
return &users[i]
|
return &users[i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -476,3 +476,58 @@ func TestSeedCommandTransactionRollback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnsureSeedUser(t *testing.T) {
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect to database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.AutoMigrate(&database.User{}); err != nil {
|
||||||
|
t.Fatalf("Failed to migrate database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userRepo := repositories.NewUserRepository(db)
|
||||||
|
passwordHash := "test_password_hash"
|
||||||
|
|
||||||
|
firstUser, err := ensureSeedUser(userRepo, passwordHash)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create seed user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if firstUser.Username != "seed_admin" || firstUser.Email != "seed_admin@goyco.local" || firstUser.Password != passwordHash || !firstUser.EmailVerified {
|
||||||
|
t.Errorf("Invalid seed user: username=%s, email=%s, password matches=%v, emailVerified=%v",
|
||||||
|
firstUser.Username, firstUser.Email, firstUser.Password == passwordHash, firstUser.EmailVerified)
|
||||||
|
}
|
||||||
|
|
||||||
|
secondUser, err := ensureSeedUser(userRepo, "different_password_hash")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to reuse seed user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if firstUser.ID != secondUser.ID {
|
||||||
|
t.Errorf("Expected same user to be reused (ID %d), got different user (ID %d)", firstUser.ID, secondUser.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if _, err := ensureSeedUser(userRepo, passwordHash); err != nil {
|
||||||
|
t.Fatalf("Call %d failed: %v", i+1, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
users, err := userRepo.GetAll(100, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get users: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
seedUserCount := 0
|
||||||
|
for _, user := range users {
|
||||||
|
if user.Username == "seed_admin" {
|
||||||
|
seedUserCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if seedUserCount != 1 {
|
||||||
|
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"goyco/internal/config"
|
"goyco/internal/config"
|
||||||
@@ -105,9 +106,9 @@ func defaultRateLimitConfig() config.RateLimitConfig {
|
|||||||
return testutils.AppTestConfig.RateLimit
|
return testutils.AppTestConfig.RateLimit
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAPIRootRouting(t *testing.T) {
|
func createDefaultRouterConfig() RouterConfig {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
||||||
router := NewRouter(RouterConfig{
|
return RouterConfig{
|
||||||
APIHandler: apiHandler,
|
APIHandler: apiHandler,
|
||||||
AuthHandler: authHandler,
|
AuthHandler: authHandler,
|
||||||
PostHandler: postHandler,
|
PostHandler: postHandler,
|
||||||
@@ -115,7 +116,15 @@ func TestAPIRootRouting(t *testing.T) {
|
|||||||
UserHandler: userHandler,
|
UserHandler: userHandler,
|
||||||
AuthService: authService,
|
AuthService: authService,
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
RateLimitConfig: defaultRateLimitConfig(),
|
||||||
})
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestRouter(cfg RouterConfig) http.Handler {
|
||||||
|
return NewRouter(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIRootRouting(t *testing.T) {
|
||||||
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -141,23 +150,23 @@ func TestAPIRootRouting(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProtectedRoutesRequireAuth(t *testing.T) {
|
func TestProtectedRoutesRequireAuth(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
router := NewRouter(RouterConfig{
|
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
protectedRoutes := []struct {
|
protectedRoutes := []struct {
|
||||||
method string
|
method string
|
||||||
path string
|
path string
|
||||||
}{
|
}{
|
||||||
{http.MethodGet, "/api/auth/me"},
|
{http.MethodGet, "/api/auth/me"},
|
||||||
|
{http.MethodPost, "/api/auth/logout"},
|
||||||
|
{http.MethodPost, "/api/auth/revoke"},
|
||||||
|
{http.MethodPost, "/api/auth/revoke-all"},
|
||||||
|
{http.MethodPut, "/api/auth/email"},
|
||||||
|
{http.MethodPut, "/api/auth/username"},
|
||||||
|
{http.MethodPut, "/api/auth/password"},
|
||||||
|
{http.MethodDelete, "/api/auth/account"},
|
||||||
{http.MethodPost, "/api/posts"},
|
{http.MethodPost, "/api/posts"},
|
||||||
|
{http.MethodPut, "/api/posts/1"},
|
||||||
|
{http.MethodDelete, "/api/posts/1"},
|
||||||
{http.MethodPost, "/api/posts/1/vote"},
|
{http.MethodPost, "/api/posts/1/vote"},
|
||||||
{http.MethodDelete, "/api/posts/1/vote"},
|
{http.MethodDelete, "/api/posts/1/vote"},
|
||||||
{http.MethodGet, "/api/posts/1/vote"},
|
{http.MethodGet, "/api/posts/1/vote"},
|
||||||
@@ -183,17 +192,9 @@ func TestProtectedRoutesRequireAuth(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterWithDebugMode(t *testing.T) {
|
func TestRouterWithDebugMode(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
cfg := createDefaultRouterConfig()
|
||||||
router := NewRouter(RouterConfig{
|
cfg.Debug = true
|
||||||
Debug: true,
|
router := createTestRouter(cfg)
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -206,16 +207,9 @@ func TestRouterWithDebugMode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterWithCacheDisabled(t *testing.T) {
|
func TestRouterWithCacheDisabled(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
cfg := createDefaultRouterConfig()
|
||||||
router := NewRouter(RouterConfig{
|
cfg.DisableCache = true
|
||||||
DisableCache: true,
|
router := createTestRouter(cfg)
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
})
|
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -228,17 +222,9 @@ func TestRouterWithCacheDisabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterWithCompressionDisabled(t *testing.T) {
|
func TestRouterWithCompressionDisabled(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
cfg := createDefaultRouterConfig()
|
||||||
router := NewRouter(RouterConfig{
|
cfg.DisableCompression = true
|
||||||
DisableCompression: true,
|
router := createTestRouter(cfg)
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -251,19 +237,9 @@ func TestRouterWithCompressionDisabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterWithCustomDBMonitor(t *testing.T) {
|
func TestRouterWithCustomDBMonitor(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
cfg := createDefaultRouterConfig()
|
||||||
customDBMonitor := middleware.NewInMemoryDBMonitor()
|
cfg.DBMonitor = middleware.NewInMemoryDBMonitor()
|
||||||
|
router := createTestRouter(cfg)
|
||||||
router := NewRouter(RouterConfig{
|
|
||||||
DBMonitor: customDBMonitor,
|
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -296,18 +272,9 @@ func TestRouterWithPageHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterWithStaticDir(t *testing.T) {
|
func TestRouterWithStaticDir(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
cfg := createDefaultRouterConfig()
|
||||||
|
cfg.StaticDir = "/custom/static/path"
|
||||||
router := NewRouter(RouterConfig{
|
router := createTestRouter(cfg)
|
||||||
StaticDir: "/custom/static/path",
|
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -320,18 +287,9 @@ func TestRouterWithStaticDir(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterWithEmptyStaticDir(t *testing.T) {
|
func TestRouterWithEmptyStaticDir(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
cfg := createDefaultRouterConfig()
|
||||||
|
cfg.StaticDir = ""
|
||||||
router := NewRouter(RouterConfig{
|
router := createTestRouter(cfg)
|
||||||
StaticDir: "",
|
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -344,20 +302,11 @@ func TestRouterWithEmptyStaticDir(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterWithAllFeaturesDisabled(t *testing.T) {
|
func TestRouterWithAllFeaturesDisabled(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
cfg := createDefaultRouterConfig()
|
||||||
|
cfg.Debug = true
|
||||||
router := NewRouter(RouterConfig{
|
cfg.DisableCache = true
|
||||||
Debug: true,
|
cfg.DisableCompression = true
|
||||||
DisableCache: true,
|
router := createTestRouter(cfg)
|
||||||
DisableCompression: true,
|
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -370,15 +319,9 @@ func TestRouterWithAllFeaturesDisabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterWithoutAPIHandler(t *testing.T) {
|
func TestRouterWithoutAPIHandler(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, _, authService := setupTestHandlers()
|
cfg := createDefaultRouterConfig()
|
||||||
router := NewRouter(RouterConfig{
|
cfg.APIHandler = nil
|
||||||
AuthHandler: authHandler,
|
router := createTestRouter(cfg)
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -391,17 +334,7 @@ func TestRouterWithoutAPIHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterWithoutPageHandler(t *testing.T) {
|
func TestRouterWithoutPageHandler(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
router := NewRouter(RouterConfig{
|
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/", nil)
|
request := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -414,17 +347,7 @@ func TestRouterWithoutPageHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSwaggerRoute(t *testing.T) {
|
func TestSwaggerRoute(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
router := NewRouter(RouterConfig{
|
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
|
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -437,18 +360,9 @@ func TestSwaggerRoute(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStaticFileRoute(t *testing.T) {
|
func TestStaticFileRoute(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
cfg := createDefaultRouterConfig()
|
||||||
|
cfg.StaticDir = "../../internal/static/"
|
||||||
router := NewRouter(RouterConfig{
|
router := createTestRouter(cfg)
|
||||||
StaticDir: "../../internal/static/",
|
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/static/css/main.css", nil)
|
request := httptest.NewRequest(http.MethodGet, "/static/css/main.css", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -461,17 +375,7 @@ func TestStaticFileRoute(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterConfiguration(t *testing.T) {
|
func TestRouterConfiguration(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
router := NewRouter(RouterConfig{
|
|
||||||
APIHandler: apiHandler,
|
|
||||||
AuthHandler: authHandler,
|
|
||||||
PostHandler: postHandler,
|
|
||||||
VoteHandler: voteHandler,
|
|
||||||
UserHandler: userHandler,
|
|
||||||
AuthService: authService,
|
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
if router == nil {
|
if router == nil {
|
||||||
t.Error("Router should not be nil")
|
t.Error("Router should not be nil")
|
||||||
@@ -487,29 +391,484 @@ func TestRouterConfiguration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterMiddlewareIntegration(t *testing.T) {
|
func TestAllRoutesExist(t *testing.T) {
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
router := NewRouter(RouterConfig{
|
publicRoutes := []struct {
|
||||||
APIHandler: apiHandler,
|
method string
|
||||||
AuthHandler: authHandler,
|
path string
|
||||||
PostHandler: postHandler,
|
description string
|
||||||
VoteHandler: voteHandler,
|
}{
|
||||||
UserHandler: userHandler,
|
{http.MethodGet, "/api", "API info"},
|
||||||
AuthService: authService,
|
{http.MethodGet, "/health", "Health check"},
|
||||||
RateLimitConfig: defaultRateLimitConfig(),
|
{http.MethodGet, "/metrics", "Metrics"},
|
||||||
})
|
{http.MethodGet, "/robots.txt", "Robots.txt"},
|
||||||
|
{http.MethodGet, "/api/posts", "Get posts"},
|
||||||
if router == nil {
|
{http.MethodGet, "/api/posts/search", "Search posts"},
|
||||||
t.Error("Router should not be nil")
|
{http.MethodGet, "/api/posts/title", "Fetch title from URL"},
|
||||||
|
{http.MethodGet, "/api/posts/1", "Get post by ID"},
|
||||||
|
{http.MethodPost, "/api/auth/register", "Register"},
|
||||||
|
{http.MethodPost, "/api/auth/login", "Login"},
|
||||||
|
{http.MethodPost, "/api/auth/refresh", "Refresh token"},
|
||||||
|
{http.MethodGet, "/api/auth/confirm", "Confirm email"},
|
||||||
|
{http.MethodPost, "/api/auth/resend-verification", "Resend verification"},
|
||||||
|
{http.MethodPost, "/api/auth/forgot-password", "Forgot password"},
|
||||||
|
{http.MethodPost, "/api/auth/reset-password", "Reset password"},
|
||||||
|
{http.MethodPost, "/api/auth/account/confirm", "Confirm account deletion"},
|
||||||
}
|
}
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
protectedRoutes := []struct {
|
||||||
recorder := httptest.NewRecorder()
|
method string
|
||||||
|
path string
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{http.MethodGet, "/api/auth/me", "Get current user"},
|
||||||
|
{http.MethodPost, "/api/auth/logout", "Logout"},
|
||||||
|
{http.MethodPost, "/api/auth/revoke", "Revoke token"},
|
||||||
|
{http.MethodPost, "/api/auth/revoke-all", "Revoke all tokens"},
|
||||||
|
{http.MethodPut, "/api/auth/email", "Update email"},
|
||||||
|
{http.MethodPut, "/api/auth/username", "Update username"},
|
||||||
|
{http.MethodPut, "/api/auth/password", "Update password"},
|
||||||
|
{http.MethodDelete, "/api/auth/account", "Delete account"},
|
||||||
|
{http.MethodPost, "/api/posts", "Create post"},
|
||||||
|
{http.MethodPut, "/api/posts/1", "Update post"},
|
||||||
|
{http.MethodDelete, "/api/posts/1", "Delete post"},
|
||||||
|
{http.MethodPost, "/api/posts/1/vote", "Cast vote"},
|
||||||
|
{http.MethodDelete, "/api/posts/1/vote", "Remove vote"},
|
||||||
|
{http.MethodGet, "/api/posts/1/vote", "Get user vote"},
|
||||||
|
{http.MethodGet, "/api/posts/1/votes", "Get post votes"},
|
||||||
|
{http.MethodGet, "/api/users", "Get users"},
|
||||||
|
{http.MethodPost, "/api/users", "Create user"},
|
||||||
|
{http.MethodGet, "/api/users/1", "Get user by ID"},
|
||||||
|
{http.MethodGet, "/api/users/1/posts", "Get user posts"},
|
||||||
|
}
|
||||||
|
|
||||||
router.ServeHTTP(recorder, request)
|
for _, route := range publicRoutes {
|
||||||
|
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
|
||||||
|
invalidMethod := http.MethodPatch
|
||||||
|
switch route.method {
|
||||||
|
case http.MethodGet:
|
||||||
|
invalidMethod = http.MethodDelete
|
||||||
|
case http.MethodPost:
|
||||||
|
invalidMethod = http.MethodGet
|
||||||
|
}
|
||||||
|
request := httptest.NewRequest(invalidMethod, route.path, nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
if recorder.Code == 0 {
|
router.ServeHTTP(recorder, request)
|
||||||
t.Error("Router should return a status code")
|
|
||||||
|
routeExists := recorder.Code == http.StatusMethodNotAllowed
|
||||||
|
|
||||||
|
if !routeExists {
|
||||||
|
request = httptest.NewRequest(route.method, route.path, nil)
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if recorder.Code == http.StatusNotFound && route.path != "/api/posts/1" && route.path != "/robots.txt" {
|
||||||
|
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range protectedRoutes {
|
||||||
|
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
|
||||||
|
request := httptest.NewRequest(route.method, route.path, nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if recorder.Code == http.StatusNotFound {
|
||||||
|
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
|
||||||
|
}
|
||||||
|
if recorder.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("Protected route %s %s should return 401 without auth, got %d", route.method, route.path, recorder.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouteParameters(t *testing.T) {
|
||||||
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
pathPattern string
|
||||||
|
testIDs []string
|
||||||
|
isProtected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Get post by ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
pathPattern: "/api/posts/{id}",
|
||||||
|
testIDs: []string{"1", "42", "999", "12345"},
|
||||||
|
isProtected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update post by ID",
|
||||||
|
method: http.MethodPut,
|
||||||
|
pathPattern: "/api/posts/{id}",
|
||||||
|
testIDs: []string{"1", "42", "999"},
|
||||||
|
isProtected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Delete post by ID",
|
||||||
|
method: http.MethodDelete,
|
||||||
|
pathPattern: "/api/posts/{id}",
|
||||||
|
testIDs: []string{"1", "42", "999"},
|
||||||
|
isProtected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get user by ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
pathPattern: "/api/users/{id}",
|
||||||
|
testIDs: []string{"1", "42", "999", "12345"},
|
||||||
|
isProtected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get user posts by user ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
pathPattern: "/api/users/{id}/posts",
|
||||||
|
testIDs: []string{"1", "42", "999", "12345"},
|
||||||
|
isProtected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Cast vote for post ID",
|
||||||
|
method: http.MethodPost,
|
||||||
|
pathPattern: "/api/posts/{id}/vote",
|
||||||
|
testIDs: []string{"1", "42", "999"},
|
||||||
|
isProtected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Remove vote for post ID",
|
||||||
|
method: http.MethodDelete,
|
||||||
|
pathPattern: "/api/posts/{id}/vote",
|
||||||
|
testIDs: []string{"1", "42", "999"},
|
||||||
|
isProtected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get user vote for post ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
pathPattern: "/api/posts/{id}/vote",
|
||||||
|
testIDs: []string{"1", "42", "999", "12345"},
|
||||||
|
isProtected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get post votes by post ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
pathPattern: "/api/posts/{id}/votes",
|
||||||
|
testIDs: []string{"1", "42", "999", "12345"},
|
||||||
|
isProtected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
for _, id := range tc.testIDs {
|
||||||
|
path := replaceID(tc.pathPattern, id)
|
||||||
|
t.Run("ID_"+id, func(t *testing.T) {
|
||||||
|
request := httptest.NewRequest(http.MethodPatch, path, nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
routeExists := recorder.Code == http.StatusMethodNotAllowed
|
||||||
|
|
||||||
|
request = httptest.NewRequest(tc.method, path, nil)
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if !routeExists {
|
||||||
|
if recorder.Code == http.StatusNotFound {
|
||||||
|
t.Errorf("Route %s %s should exist with ID %s, got 404", tc.method, path, id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.isProtected {
|
||||||
|
if recorder.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("Protected route %s %s should return 401 without auth, got %d", tc.method, path, recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func replaceID(pattern, id string) string {
|
||||||
|
return strings.Replace(pattern, "{id}", id, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidRouteParameters(t *testing.T) {
|
||||||
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
expectedMin int
|
||||||
|
expectedMax int
|
||||||
|
isProtected bool
|
||||||
|
allow401 bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Non-numeric post ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/abc",
|
||||||
|
expectedMin: http.StatusBadRequest,
|
||||||
|
expectedMax: http.StatusBadRequest,
|
||||||
|
isProtected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Negative post ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/-1",
|
||||||
|
expectedMin: http.StatusBadRequest,
|
||||||
|
expectedMax: http.StatusBadRequest,
|
||||||
|
isProtected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Zero post ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/0",
|
||||||
|
expectedMin: http.StatusBadRequest,
|
||||||
|
expectedMax: http.StatusNotFound,
|
||||||
|
isProtected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Post ID with special characters",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/123@456",
|
||||||
|
expectedMin: http.StatusBadRequest,
|
||||||
|
expectedMax: http.StatusBadRequest,
|
||||||
|
isProtected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Post ID with encoded spaces",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/12%2034",
|
||||||
|
expectedMin: http.StatusBadRequest,
|
||||||
|
expectedMax: http.StatusBadRequest,
|
||||||
|
isProtected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-numeric user ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/users/xyz",
|
||||||
|
expectedMin: http.StatusBadRequest,
|
||||||
|
expectedMax: http.StatusUnauthorized,
|
||||||
|
isProtected: true,
|
||||||
|
allow401: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Negative user ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/users/-5",
|
||||||
|
expectedMin: http.StatusBadRequest,
|
||||||
|
expectedMax: http.StatusUnauthorized,
|
||||||
|
isProtected: true,
|
||||||
|
allow401: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-numeric post ID in vote route",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/invalid/vote",
|
||||||
|
expectedMin: http.StatusBadRequest,
|
||||||
|
expectedMax: http.StatusUnauthorized,
|
||||||
|
isProtected: true,
|
||||||
|
allow401: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Very large post ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/999999999999",
|
||||||
|
expectedMin: http.StatusBadRequest,
|
||||||
|
expectedMax: http.StatusNotFound,
|
||||||
|
isProtected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
request := httptest.NewRequest(tc.method, tc.path, nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if tc.isProtected && tc.allow401 {
|
||||||
|
if recorder.Code != http.StatusUnauthorized && (recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax) {
|
||||||
|
t.Errorf("Protected route %s %s with invalid parameter should return 401 or status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax {
|
||||||
|
t.Errorf("Route %s %s should return status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusNotFound && recorder.Code < 400 {
|
||||||
|
t.Errorf("Route %s %s with invalid parameter should return error status (4xx), got %d", tc.method, tc.path, recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryParameters(t *testing.T) {
|
||||||
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
queryParams string
|
||||||
|
expectRoute bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Get posts with limit and offset",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts",
|
||||||
|
queryParams: "limit=10&offset=5",
|
||||||
|
expectRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get posts with only limit",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts",
|
||||||
|
queryParams: "limit=20",
|
||||||
|
expectRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get posts with only offset",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts",
|
||||||
|
queryParams: "offset=10",
|
||||||
|
expectRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Search posts with query parameter",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/search",
|
||||||
|
queryParams: "q=test",
|
||||||
|
expectRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Search posts with query, limit, and offset",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/search",
|
||||||
|
queryParams: "q=test&limit=15&offset=3",
|
||||||
|
expectRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Fetch title with URL parameter",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/title",
|
||||||
|
queryParams: "url=https://example.com",
|
||||||
|
expectRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Confirm email with token parameter",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/auth/confirm",
|
||||||
|
queryParams: "token=abc123",
|
||||||
|
expectRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get posts with invalid limit",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts",
|
||||||
|
queryParams: "limit=abc",
|
||||||
|
expectRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get posts with negative limit",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts",
|
||||||
|
queryParams: "limit=-5",
|
||||||
|
expectRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get posts with negative offset",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts",
|
||||||
|
queryParams: "offset=-10",
|
||||||
|
expectRoute: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
fullPath := tc.path
|
||||||
|
if tc.queryParams != "" {
|
||||||
|
fullPath += "?" + tc.queryParams
|
||||||
|
}
|
||||||
|
|
||||||
|
request := httptest.NewRequest(tc.method, fullPath, nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if tc.expectRoute {
|
||||||
|
if recorder.Code == http.StatusNotFound {
|
||||||
|
t.Errorf("Route %s %s should exist with query parameters, got 404", tc.method, fullPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouteConflicts(t *testing.T) {
|
||||||
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "posts/search should not match posts/{id}",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/search",
|
||||||
|
description: "search route should be matched, not treated as ID",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "posts/title should not match posts/{id}",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/title",
|
||||||
|
description: "title route should be matched, not treated as ID",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "posts/{id} should work with numeric ID",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/posts/123",
|
||||||
|
description: "numeric ID should match {id} route",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
request := httptest.NewRequest(tc.method, tc.path, nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
switch tc.path {
|
||||||
|
case "/api/posts/search":
|
||||||
|
if recorder.Code == http.StatusNotFound {
|
||||||
|
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
|
||||||
|
}
|
||||||
|
case "/api/posts/title":
|
||||||
|
if recorder.Code == http.StatusNotFound {
|
||||||
|
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
|
||||||
|
}
|
||||||
|
case "/api/posts/123":
|
||||||
|
if recorder.Code == http.StatusNotFound {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if recorder.Code < 400 {
|
||||||
|
t.Errorf("%s: Route %s %s should return 4xx or 5xx, got %d", tc.description, tc.method, tc.path, recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,93 @@
|
|||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/mail"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"goyco/internal/config"
|
"goyco/internal/config"
|
||||||
"goyco/internal/database"
|
"goyco/internal/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultTokenExpirationHours = 24
|
||||||
|
verificationTokenBytes = 32
|
||||||
|
deletionTokenExpirationHours = 24
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||||
|
ErrInvalidToken = errors.New("invalid or expired token")
|
||||||
|
ErrUsernameTaken = errors.New("username already exists")
|
||||||
|
ErrEmailTaken = errors.New("email already exists")
|
||||||
|
ErrInvalidEmail = errors.New("invalid email address")
|
||||||
|
ErrPasswordTooShort = errors.New("password too short")
|
||||||
|
ErrEmailNotVerified = errors.New("email not verified")
|
||||||
|
ErrAccountLocked = errors.New("account is locked")
|
||||||
|
ErrInvalidVerificationToken = errors.New("invalid verification token")
|
||||||
|
ErrEmailSenderUnavailable = errors.New("email sender not configured")
|
||||||
|
ErrDeletionEmailFailed = errors.New("account deletion email failed")
|
||||||
|
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
|
||||||
|
ErrUserNotFound = errors.New("user not found")
|
||||||
|
ErrDeletionRequestNotFound = errors.New("deletion request not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthResult struct {
|
||||||
|
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
|
||||||
|
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
|
||||||
|
User *database.User `json:"user"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RegistrationResult struct {
|
||||||
|
User *database.User `json:"user"`
|
||||||
|
VerificationSent bool `json:"verification_sent"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeEmail(email string) (string, error) {
|
||||||
|
trimmed := strings.TrimSpace(email)
|
||||||
|
if trimmed == "" {
|
||||||
|
return "", fmt.Errorf("email is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := mail.ParseAddress(trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return "", ErrInvalidEmail
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.ToLower(parsed.Address), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateVerificationToken() (string, string, error) {
|
||||||
|
buf := make([]byte, verificationTokenBytes)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return "", "", fmt.Errorf("generate verification token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := hex.EncodeToString(buf)
|
||||||
|
hashed := HashVerificationToken(token)
|
||||||
|
return token, hashed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func HashVerificationToken(token string) string {
|
||||||
|
sum := sha256.Sum256([]byte(token))
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeUser(user *database.User) *database.User {
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
copy := *user
|
||||||
|
copy.Password = ""
|
||||||
|
copy.EmailVerificationToken = ""
|
||||||
|
return ©
|
||||||
|
}
|
||||||
|
|
||||||
type AuthFacade struct {
|
type AuthFacade struct {
|
||||||
registrationService *RegistrationService
|
registrationService *RegistrationService
|
||||||
passwordResetService *PasswordResetService
|
passwordResetService *PasswordResetService
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
package services
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"goyco/internal/database"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
|
||||||
ErrInvalidToken = errors.New("invalid or expired token")
|
|
||||||
ErrUsernameTaken = errors.New("username already exists")
|
|
||||||
ErrEmailTaken = errors.New("email already exists")
|
|
||||||
ErrInvalidEmail = errors.New("invalid email address")
|
|
||||||
ErrPasswordTooShort = errors.New("password too short")
|
|
||||||
ErrEmailNotVerified = errors.New("email not verified")
|
|
||||||
ErrAccountLocked = errors.New("account is locked")
|
|
||||||
ErrInvalidVerificationToken = errors.New("invalid verification token")
|
|
||||||
ErrEmailSenderUnavailable = errors.New("email sender not configured")
|
|
||||||
ErrDeletionEmailFailed = errors.New("account deletion email failed")
|
|
||||||
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
|
|
||||||
ErrUserNotFound = errors.New("user not found")
|
|
||||||
ErrDeletionRequestNotFound = errors.New("deletion request not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type AuthResult struct {
|
|
||||||
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
|
|
||||||
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
|
|
||||||
User *database.User `json:"user"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type RegistrationResult struct {
|
|
||||||
User *database.User `json:"user"`
|
|
||||||
VerificationSent bool `json:"verification_sent"`
|
|
||||||
}
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
package services
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"net/mail"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"goyco/internal/database"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultTokenExpirationHours = 24
|
|
||||||
verificationTokenBytes = 32
|
|
||||||
deletionTokenExpirationHours = 24
|
|
||||||
)
|
|
||||||
|
|
||||||
func normalizeEmail(email string) (string, error) {
|
|
||||||
trimmed := strings.TrimSpace(email)
|
|
||||||
if trimmed == "" {
|
|
||||||
return "", fmt.Errorf("email is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
parsed, err := mail.ParseAddress(trimmed)
|
|
||||||
if err != nil {
|
|
||||||
return "", ErrInvalidEmail
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.ToLower(parsed.Address), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateVerificationToken() (string, string, error) {
|
|
||||||
buf := make([]byte, verificationTokenBytes)
|
|
||||||
if _, err := rand.Read(buf); err != nil {
|
|
||||||
return "", "", fmt.Errorf("generate verification token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
token := hex.EncodeToString(buf)
|
|
||||||
hashed := HashVerificationToken(token)
|
|
||||||
return token, hashed, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func HashVerificationToken(token string) string {
|
|
||||||
sum := sha256.Sum256([]byte(token))
|
|
||||||
return hex.EncodeToString(sum[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func sanitizeUser(user *database.User) *database.User {
|
|
||||||
if user == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
copy := *user
|
|
||||||
copy.Password = ""
|
|
||||||
copy.EmailVerificationToken = ""
|
|
||||||
return ©
|
|
||||||
}
|
|
||||||
@@ -32,10 +32,7 @@ func templateFuncMap() template.FuncMap {
|
|||||||
if start >= len(s) {
|
if start >= len(s) {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
end := start + length
|
end := min(start+length, len(s))
|
||||||
if end > len(s) {
|
|
||||||
end = len(s)
|
|
||||||
}
|
|
||||||
return s[start:end]
|
return s[start:end]
|
||||||
},
|
},
|
||||||
"upper": strings.ToUpper,
|
"upper": strings.ToUpper,
|
||||||
|
|||||||
@@ -2,21 +2,21 @@
|
|||||||
# helper script to setup a postgres database on deb based systems
|
# helper script to setup a postgres database on deb based systems
|
||||||
|
|
||||||
if [ "$EUID" -ne 0 ]; then
|
if [ "$EUID" -ne 0 ]; then
|
||||||
echo "Please run as root"
|
echo "Please run as root"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
read -s "Do you want to install PostgreSQL 18? [y/N] " INSTALL_PG
|
read -s "Do you want to install PostgreSQL 18? [y/N] " INSTALL_PG
|
||||||
if [ "$INSTALL_PG" != "y" ]; then
|
if [ "$INSTALL_PG" != "y" ]; then
|
||||||
echo "PostgreSQL 18 will not be installed"
|
echo "PostgreSQL 18 will not be installed"
|
||||||
exit 0
|
exit 0
|
||||||
fi
|
fi
|
||||||
|
|
||||||
read -s -p "Enter password for PostgreSQL user 'goyco': " GOYCO_PWD
|
read -s -p "Enter password for PostgreSQL user 'goyco': " GOYCO_PWD
|
||||||
echo
|
echo
|
||||||
|
|
||||||
apt-get update
|
apt-get update
|
||||||
apt-get install -y postgresql-18
|
apt-get install -y postgresql-18
|
||||||
|
|
||||||
systemctl enable --now postgresql
|
systemctl enable --now postgresql
|
||||||
|
|
||||||
@@ -44,5 +44,3 @@ GRANT ALL PRIVILEGES ON DATABASE goyco TO goyco;
|
|||||||
EOF
|
EOF
|
||||||
|
|
||||||
echo "PostgreSQL 18 installed, database 'goyco' and user 'goyco' set up."
|
echo "PostgreSQL 18 installed, database 'goyco' and user 'goyco' set up."
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user