To gitea and beyond, let's go(-yco)

This commit is contained in:
2025-11-10 19:12:09 +01:00
parent 8f6133392d
commit 71a031342b
245 changed files with 83994 additions and 0 deletions

56
cmd/goyco/cli.go Normal file
View File

@@ -0,0 +1,56 @@
package main
import (
"errors"
"flag"
"fmt"
"os"
"github.com/joho/godotenv"
"goyco/cmd/goyco/commands"
)
func loadDotEnv() {
if _, err := os.Stat(".env"); err == nil {
_ = godotenv.Load()
return
}
}
func newFlagSet(name string, usage func()) *flag.FlagSet {
fs := flag.NewFlagSet(name, flag.ContinueOnError)
fs.SetOutput(os.Stderr)
if usage != nil {
fs.Usage = usage
}
return fs
}
func parseCommand(fs *flag.FlagSet, args []string, context string) error {
if err := fs.Parse(args); err != nil {
if errors.Is(err, flag.ErrHelp) {
return commands.ErrHelpRequested
}
return fmt.Errorf("failed to parse %s command: %w", context, err)
}
return nil
}
func printRootUsage() {
fmt.Fprintf(os.Stderr, "Usage: %s <command> [<args>]\n", os.Args[0])
fmt.Fprintln(os.Stderr, "\nCommands:")
fmt.Fprintln(os.Stderr, " run start the web application in foreground")
fmt.Fprintln(os.Stderr, " start start the web application in background")
fmt.Fprintln(os.Stderr, " stop stop the daemon")
fmt.Fprintln(os.Stderr, " status check if the daemon is running")
fmt.Fprintln(os.Stderr, " migrate apply database migrations")
fmt.Fprintln(os.Stderr, " user manage users (create, update, delete, lock, list)")
fmt.Fprintln(os.Stderr, " post manage posts (delete, list, search)")
fmt.Fprintln(os.Stderr, " prune hard delete users and posts (posts, all)")
fmt.Fprintln(os.Stderr, " seed seed database with random data")
}
func printRunUsage() {
fmt.Fprintln(os.Stderr, "Usage: goyco run")
fmt.Fprintln(os.Stderr, "\nStart the web application in foreground.")
}

390
cmd/goyco/cli_test.go Normal file
View File

@@ -0,0 +1,390 @@
package main
import (
"errors"
"flag"
"os"
"strings"
"testing"
"gorm.io/gorm"
"goyco/cmd/goyco/commands"
"goyco/internal/config"
"goyco/internal/testutils"
)
func TestLoadDotEnv(t *testing.T) {
t.Run("no .env file", func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("loadDotEnv panicked: %v", r)
}
}()
loadDotEnv()
})
}
func TestNewFlagSet(t *testing.T) {
tests := []struct {
name string
flagName string
usage func()
}{
{
name: "with usage function",
flagName: "test",
usage: func() { _, _ = os.Stderr.WriteString("test usage") },
},
{
name: "without usage function",
flagName: "test2",
usage: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fs := newFlagSet(tt.flagName, tt.usage)
if fs.Name() != tt.flagName {
t.Errorf("expected flag set name %q, got %q", tt.flagName, fs.Name())
}
if tt.usage != nil && fs.Usage == nil {
t.Error("expected usage function to be set")
}
})
}
}
func TestParseCommand(t *testing.T) {
tests := []struct {
name string
args []string
context string
expectError bool
expectHelp bool
}{
{
name: "valid arguments",
args: []string{"--help"},
context: "test",
expectError: true,
expectHelp: true,
},
{
name: "invalid flag",
args: []string{"--invalid-flag"},
context: "test",
expectError: true,
expectHelp: false,
},
{
name: "empty arguments",
args: []string{},
context: "test",
expectError: false,
expectHelp: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
err := parseCommand(fs, tt.args, tt.context)
if tt.expectError && err == nil {
t.Error("expected error but got none")
}
if !tt.expectError && err != nil {
t.Errorf("unexpected error: %v", err)
}
if tt.expectHelp && !errors.Is(err, commands.ErrHelpRequested) {
t.Error("expected help requested error")
}
})
}
}
func TestPrintRootUsage(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("printRootUsage panicked: %v", r)
}
}()
printRootUsage()
}
func TestPrintRunUsage(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("printRunUsage panicked: %v", r)
}
}()
printRunUsage()
}
func TestDispatchCommand(t *testing.T) {
t.Run("unknown command", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := dispatchCommand(cfg, "unknown", []string{})
if err == nil {
t.Error("expected error for unknown command")
}
expectedErr := "unknown command: unknown"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("help command", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := dispatchCommand(cfg, "help", []string{})
if err != nil {
t.Errorf("unexpected error for help command: %v", err)
}
})
t.Run("h command", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := dispatchCommand(cfg, "-h", []string{})
if err != nil {
t.Errorf("unexpected error for -h command: %v", err)
}
})
t.Run("--help command", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := dispatchCommand(cfg, "--help", []string{})
if err != nil {
t.Errorf("unexpected error for --help command: %v", err)
}
})
t.Run("post list with injected database", func(t *testing.T) {
cfg := testutils.NewTestConfig()
useInMemoryCommandsConnector(t)
err := dispatchCommand(cfg, "post", []string{"list"})
if err != nil {
t.Errorf("unexpected error for post list: %v", err)
}
})
}
func TestHandleRunCommand(t *testing.T) {
t.Run("help requested", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := handleRunCommand(cfg, []string{"--help"})
if err != nil {
t.Errorf("unexpected error for help: %v", err)
}
})
t.Run("unexpected arguments", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := handleRunCommand(cfg, []string{"extra", "args"})
if err == nil {
t.Error("expected error for unexpected arguments")
}
expectedErr := "unexpected arguments for run command"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
}
func TestRun(t *testing.T) {
t.Run("no arguments", func(t *testing.T) {
err := run([]string{})
if err != nil {
t.Logf("Expected error in test environment: %v", err)
}
})
t.Run("help flag", func(t *testing.T) {
err := run([]string{"--help"})
if err == nil {
t.Error("expected config loading error in test environment")
}
})
t.Run("invalid flag", func(t *testing.T) {
err := run([]string{"--invalid-flag"})
if err == nil {
t.Error("expected error for invalid flag")
}
})
}
func TestRunE2E_CommandParsing(t *testing.T) {
setupTestEnv(t)
t.Run("help command succeeds", func(t *testing.T) {
err := run([]string{"help"})
if err != nil {
t.Errorf("Expected help command to succeed, got error: %v", err)
}
})
t.Run("unknown command fails with error", func(t *testing.T) {
err := run([]string{"unknown-command"})
if err == nil {
t.Error("Expected error for unknown command")
}
if err != nil && !strings.Contains(err.Error(), "unknown command") {
t.Errorf("Expected error about unknown command, got: %v", err)
}
})
t.Run("migrate command parses correctly", func(t *testing.T) {
err := run([]string{"migrate", "up"})
if err != nil && strings.Contains(err.Error(), "unknown command") {
t.Errorf("Expected migrate command to be recognized, got parsing error: %v", err)
}
})
t.Run("post command parses correctly", func(t *testing.T) {
useInMemoryCommandsConnector(t)
err := run([]string{"post", "list"})
if err != nil && strings.Contains(err.Error(), "unknown command") {
t.Errorf("Expected post command to be recognized, got parsing error: %v", err)
}
})
}
func TestRunE2E_ConfigurationLoading(t *testing.T) {
t.Run("missing required env vars fails gracefully", func(t *testing.T) {
originalDBPwd := os.Getenv("DB_PASSWORD")
originalSMTPHost := os.Getenv("SMTP_HOST")
originalSMTPFrom := os.Getenv("SMTP_FROM")
originalAdminEmail := os.Getenv("ADMIN_EMAIL")
originalJWTSecret := os.Getenv("JWT_SECRET")
defer func() {
if originalDBPwd != "" {
_ = os.Setenv("DB_PASSWORD", originalDBPwd)
}
if originalSMTPHost != "" {
_ = os.Setenv("SMTP_HOST", originalSMTPHost)
}
if originalSMTPFrom != "" {
_ = os.Setenv("SMTP_FROM", originalSMTPFrom)
}
if originalAdminEmail != "" {
_ = os.Setenv("ADMIN_EMAIL", originalAdminEmail)
}
if originalJWTSecret != "" {
_ = os.Setenv("JWT_SECRET", originalJWTSecret)
}
}()
_ = os.Unsetenv("DB_PASSWORD")
_ = os.Unsetenv("SMTP_HOST")
_ = os.Unsetenv("SMTP_FROM")
_ = os.Unsetenv("ADMIN_EMAIL")
_ = os.Unsetenv("JWT_SECRET")
err := run([]string{"help"})
if err == nil {
t.Error("Expected error when required env vars are missing")
}
if err != nil && !strings.Contains(err.Error(), "configuration") && !strings.Contains(err.Error(), "config") {
t.Logf("Got error (may be expected): %v", err)
}
})
t.Run("valid configuration loads successfully", func(t *testing.T) {
setupTestEnv(t)
err := run([]string{"help"})
if err != nil {
t.Errorf("Expected help command to succeed with valid config, got: %v", err)
}
})
}
func TestRunE2E_ArgumentParsing(t *testing.T) {
setupTestEnv(t)
t.Run("root help flag", func(t *testing.T) {
err := run([]string{"--help"})
if err != nil && !strings.Contains(err.Error(), "flag") {
t.Logf("Got error (may be expected in test env): %v", err)
}
})
t.Run("command with help flag", func(t *testing.T) {
err := run([]string{"migrate", "--help"})
if err != nil && strings.Contains(err.Error(), "unknown command") {
t.Errorf("Expected migrate command to be recognized, got: %v", err)
}
})
t.Run("command with invalid arguments", func(t *testing.T) {
err := run([]string{"run", "extra", "args"})
if err == nil {
t.Error("Expected error for unexpected arguments")
}
if err != nil && !strings.Contains(err.Error(), "unexpected arguments") {
t.Errorf("Expected error about unexpected arguments, got: %v", err)
}
})
}
func setupTestEnv(t *testing.T) {
t.Helper()
t.Setenv("DB_PASSWORD", "test-password")
t.Setenv("SMTP_HOST", "smtp.example.com")
t.Setenv("SMTP_FROM", "test@example.com")
t.Setenv("ADMIN_EMAIL", "admin@example.com")
t.Setenv("JWT_SECRET", "test-jwt-secret-key-that-is-long-enough-for-validation-purposes")
tmpDir := os.TempDir()
t.Setenv("LOG_DIR", tmpDir)
t.Setenv("PID_DIR", tmpDir)
}
func useInMemoryCommandsConnector(t *testing.T) {
t.Helper()
commands.SetDBConnector(func(_ *config.Config) (*gorm.DB, func() error, error) {
db := testutils.NewTestDB(t)
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("failed to access underlying sql.DB: %v", err)
}
cleanup := func() error {
return sqlDB.Close()
}
return db, cleanup, nil
})
t.Cleanup(func() {
commands.SetDBConnector(nil)
})
}

View File

@@ -0,0 +1,257 @@
package commands
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"time"
)
type AuditLogger struct {
logFile string
logger *log.Logger
}
type AuditEvent struct {
Timestamp time.Time `json:"timestamp"`
Action string `json:"action"`
Resource string `json:"resource"`
ResourceID string `json:"resource_id,omitempty"`
Details string `json:"details,omitempty"`
User string `json:"user,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
UserAgent string `json:"user_agent,omitempty"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
Changes map[string]any `json:"changes,omitempty"`
}
func NewAuditLogger(logDir string) (*AuditLogger, error) {
if logDir == "" {
logDir = "/var/log"
}
if err := os.MkdirAll(logDir, 0755); err != nil {
return nil, fmt.Errorf("create audit log directory: %w", err)
}
logFile := filepath.Join(logDir, "goyco-audit.log")
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return nil, fmt.Errorf("open audit log file: %w", err)
}
logger := log.New(file, "", 0)
return &AuditLogger{
logFile: logFile,
logger: logger,
}, nil
}
func (a *AuditLogger) LogEvent(event AuditEvent) {
if event.Timestamp.IsZero() {
event.Timestamp = time.Now()
}
jsonData, err := json.Marshal(event)
if err != nil {
a.logger.Printf("AUDIT: %s %s %s %s",
event.Timestamp.Format(time.RFC3339),
event.Action,
event.Resource,
event.Details)
return
}
a.logger.Printf("%s", string(jsonData))
}
func (a *AuditLogger) LogUserCreation(userID uint, username, email string, success bool, err error) {
event := AuditEvent{
Action: "user_create",
Resource: "user",
ResourceID: fmt.Sprintf("%d", userID),
Details: fmt.Sprintf("Created user: %s (%s)", username, email),
Success: success,
}
if err != nil {
event.Error = err.Error()
}
a.LogEvent(event)
}
func (a *AuditLogger) LogUserUpdate(userID uint, username string, changes map[string]any, success bool, err error) {
event := AuditEvent{
Action: "user_update",
Resource: "user",
ResourceID: fmt.Sprintf("%d", userID),
Details: fmt.Sprintf("Updated user: %s", username),
Changes: changes,
Success: success,
}
if err != nil {
event.Error = err.Error()
}
a.LogEvent(event)
}
func (a *AuditLogger) LogUserDeletion(userID uint, username string, deletePosts bool, success bool, err error) {
event := AuditEvent{
Action: "user_delete",
Resource: "user",
ResourceID: fmt.Sprintf("%d", userID),
Details: fmt.Sprintf("Deleted user: %s (delete_posts: %t)", username, deletePosts),
Success: success,
}
if err != nil {
event.Error = err.Error()
}
a.LogEvent(event)
}
func (a *AuditLogger) LogUserLock(userID uint, username string, locked bool, success bool, err error) {
action := "user_lock"
if !locked {
action = "user_unlock"
}
event := AuditEvent{
Action: action,
Resource: "user",
ResourceID: fmt.Sprintf("%d", userID),
Details: fmt.Sprintf("User %s: %s", username, action),
Success: success,
}
if err != nil {
event.Error = err.Error()
}
a.LogEvent(event)
}
func (a *AuditLogger) LogPostDeletion(postID uint, title string, success bool, err error) {
event := AuditEvent{
Action: "post_delete",
Resource: "post",
ResourceID: fmt.Sprintf("%d", postID),
Details: fmt.Sprintf("Deleted post: %s", title),
Success: success,
}
if err != nil {
event.Error = err.Error()
}
a.LogEvent(event)
}
func (a *AuditLogger) LogDataPruning(operation string, count int, success bool, err error) {
event := AuditEvent{
Action: "data_prune",
Resource: "data",
Details: fmt.Sprintf("Pruned %d records via %s", count, operation),
Success: success,
}
if err != nil {
event.Error = err.Error()
}
a.LogEvent(event)
}
func (a *AuditLogger) LogDatabaseMigration(operation string, success bool, err error) {
event := AuditEvent{
Action: "database_migrate",
Resource: "database",
Details: fmt.Sprintf("Database migration: %s", operation),
Success: success,
}
if err != nil {
event.Error = err.Error()
}
a.LogEvent(event)
}
func (a *AuditLogger) LogDatabaseSeeding(users, posts, votes int, success bool, err error) {
event := AuditEvent{
Action: "database_seed",
Resource: "database",
Details: fmt.Sprintf("Seeded database: %d users, %d posts, %d votes", users, posts, votes),
Success: success,
}
if err != nil {
event.Error = err.Error()
}
a.LogEvent(event)
}
func (a *AuditLogger) LogDaemonOperation(operation string, pid int, success bool, err error) {
event := AuditEvent{
Action: "daemon_" + operation,
Resource: "daemon",
Details: fmt.Sprintf("Daemon %s (PID: %d)", operation, pid),
Success: success,
}
if err != nil {
event.Error = err.Error()
}
a.LogEvent(event)
}
func (a *AuditLogger) LogSecurityEvent(eventType, details string, severity string) {
event := AuditEvent{
Action: "security_event",
Resource: "security",
Details: fmt.Sprintf("[%s] %s: %s", severity, eventType, details),
Success: true,
}
a.LogEvent(event)
}
func (a *AuditLogger) LogConfigurationChange(setting, oldValue, newValue string, success bool, err error) {
event := AuditEvent{
Action: "config_change",
Resource: "configuration",
Details: fmt.Sprintf("Changed %s from '%s' to '%s'", setting, oldValue, newValue),
Success: success,
}
if err != nil {
event.Error = err.Error()
}
a.LogEvent(event)
}
func (a *AuditLogger) GetLogFile() string {
return a.logFile
}
func (a *AuditLogger) Close() error {
a.LogEvent(AuditEvent{
Action: "audit_logger_close",
Resource: "audit",
Details: "Audit logger closed",
Success: true,
})
return nil
}

View File

@@ -0,0 +1,95 @@
package commands
import (
"errors"
"flag"
"fmt"
"os"
"sync"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
)
var ErrHelpRequested = errors.New("help requested")
type DBConnector func(cfg *config.Config) (*gorm.DB, func() error, error)
var (
dbConnectorMu sync.RWMutex
currentDBConnector = defaultDBConnector
)
func defaultDBConnector(cfg *config.Config) (*gorm.DB, func() error, error) {
db, err := database.Connect(cfg)
if err != nil {
return nil, nil, err
}
return db, func() error { return database.Close(db) }, nil
}
func SetDBConnector(connector DBConnector) {
dbConnectorMu.Lock()
defer dbConnectorMu.Unlock()
if connector == nil {
currentDBConnector = defaultDBConnector
return
}
currentDBConnector = connector
}
func getDBConnector() DBConnector {
dbConnectorMu.RLock()
defer dbConnectorMu.RUnlock()
return currentDBConnector
}
func newFlagSet(name string, usage func()) *flag.FlagSet {
fs := flag.NewFlagSet(name, flag.ContinueOnError)
fs.SetOutput(os.Stderr)
if usage != nil {
fs.Usage = usage
}
return fs
}
func parseCommand(fs *flag.FlagSet, args []string, context string) error {
if err := fs.Parse(args); err != nil {
if errors.Is(err, flag.ErrHelp) {
return ErrHelpRequested
}
return fmt.Errorf("failed to parse %s command: %w", context, err)
}
return nil
}
func withDatabase(cfg *config.Config, fn func(db *gorm.DB) error) error {
connector := getDBConnector()
db, cleanup, err := connector(cfg)
if err != nil {
return fmt.Errorf("connect to database: %w", err)
}
if cleanup != nil {
defer func() {
if err := cleanup(); err != nil {
fmt.Printf("Warning: closing database: %v\n", err)
}
}()
}
return fn(db)
}
func truncate(in string, max int) string {
if len(in) <= max {
return in
}
if max <= 3 {
return in[:max]
}
return in[:max-3] + "..."
}

View File

@@ -0,0 +1,219 @@
package commands
import (
"errors"
"flag"
"os"
"testing"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/testutils"
)
func TestNewFlagSet(t *testing.T) {
tests := []struct {
name string
flagName string
usage func()
}{
{
name: "with usage function",
flagName: "test",
usage: func() { _, _ = os.Stderr.WriteString("test usage") },
},
{
name: "without usage function",
flagName: "test2",
usage: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fs := newFlagSet(tt.flagName, tt.usage)
if fs.Name() != tt.flagName {
t.Errorf("expected flag set name %q, got %q", tt.flagName, fs.Name())
}
if tt.usage != nil && fs.Usage == nil {
t.Error("expected usage function to be set")
}
})
}
}
func TestParseCommand(t *testing.T) {
tests := []struct {
name string
args []string
context string
expectError bool
expectHelp bool
}{
{
name: "valid arguments",
args: []string{"--help"},
context: "test",
expectError: true,
expectHelp: true,
},
{
name: "invalid flag",
args: []string{"--invalid-flag"},
context: "test",
expectError: true,
expectHelp: false,
},
{
name: "empty arguments",
args: []string{},
context: "test",
expectError: false,
expectHelp: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
err := parseCommand(fs, tt.args, tt.context)
if tt.expectError && err == nil {
t.Error("expected error but got none")
}
if !tt.expectError && err != nil {
t.Errorf("unexpected error: %v", err)
}
if tt.expectHelp && !errors.Is(err, ErrHelpRequested) {
t.Error("expected help requested error")
}
})
}
}
func TestTruncate(t *testing.T) {
tests := []struct {
name string
input string
max int
expected string
}{
{
name: "string shorter than max",
input: "short",
max: 10,
expected: "short",
},
{
name: "string equal to max",
input: "exactly",
max: 7,
expected: "exactly",
},
{
name: "string longer than max",
input: "this is a very long string",
max: 10,
expected: "this is...",
},
{
name: "string longer than max with small max",
input: "hello",
max: 3,
expected: "hel",
},
{
name: "string longer than max with very small max",
input: "hello",
max: 1,
expected: "h",
},
{
name: "empty string",
input: "",
max: 5,
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := truncate(tt.input, tt.max)
if result != tt.expected {
t.Errorf("expected %q, got %q", tt.expected, result)
}
})
}
}
func TestWithDatabase(t *testing.T) {
cfg := testutils.NewTestConfig()
t.Run("custom connector success", func(t *testing.T) {
setInMemoryDBConnector(t)
var called bool
err := withDatabase(cfg, func(db *gorm.DB) error {
called = true
if db == nil {
t.Fatal("expected non-nil database")
}
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !called {
t.Fatal("expected database function to be called")
}
})
t.Run("default connector failure", func(t *testing.T) {
SetDBConnector(nil)
var called bool
err := withDatabase(cfg, func(db *gorm.DB) error {
called = true
return nil
})
if err == nil {
t.Error("expected database connection error in test environment")
}
if called {
t.Error("expected database function not to be called when connection fails")
}
})
}
func setInMemoryDBConnector(t *testing.T) {
t.Helper()
SetDBConnector(func(_ *config.Config) (*gorm.DB, func() error, error) {
db := testutils.NewTestDB(t)
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("failed to access underlying sql.DB: %v", err)
}
cleanup := func() error {
return sqlDB.Close()
}
return db, cleanup, nil
})
t.Cleanup(func() {
SetDBConnector(nil)
})
}

View File

@@ -0,0 +1,348 @@
package commands
import (
"fmt"
"net"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"goyco/internal/config"
)
type ConfigValidator struct {
auditLogger *AuditLogger
}
func NewConfigValidator(auditLogger *AuditLogger) *ConfigValidator {
return &ConfigValidator{
auditLogger: auditLogger,
}
}
func (v *ConfigValidator) ValidateConfiguration(cfg *config.Config) error {
var errors []string
if err := v.validateDatabaseConfig(cfg); err != nil {
errors = append(errors, fmt.Sprintf("Database: %v", err))
}
if err := v.validateSMTPConfig(cfg); err != nil {
errors = append(errors, fmt.Sprintf("SMTP: %v", err))
}
if err := v.validateServerConfig(cfg); err != nil {
errors = append(errors, fmt.Sprintf("Server: %v", err))
}
if err := v.validateSecurityConfig(cfg); err != nil {
errors = append(errors, fmt.Sprintf("Security: %v", err))
}
if err := v.validateFilePaths(cfg); err != nil {
errors = append(errors, fmt.Sprintf("File paths: %v", err))
}
if len(errors) > 0 {
return fmt.Errorf("configuration validation failed:\n- %s", strings.Join(errors, "\n- "))
}
if v.auditLogger != nil {
v.auditLogger.LogConfigurationChange("validation", "invalid", "valid", true, nil)
}
return nil
}
func (v *ConfigValidator) validateDatabaseConfig(cfg *config.Config) error {
if cfg.Database.Host == "" {
return fmt.Errorf("DB_HOST is required")
}
port, err := strconv.Atoi(cfg.Database.Port)
if err != nil {
return fmt.Errorf("DB_PORT must be a valid integer")
}
if port <= 0 || port > 65535 {
return fmt.Errorf("DB_PORT must be between 1 and 65535")
}
if cfg.Database.Name == "" {
return fmt.Errorf("DB_NAME is required")
}
if cfg.Database.User == "" {
return fmt.Errorf("DB_USER is required")
}
if cfg.Database.Password == "" {
return fmt.Errorf("DB_PASSWORD is required")
}
if !v.isValidHost(cfg.Database.Host) {
return fmt.Errorf("DB_HOST has invalid format")
}
return nil
}
func (v *ConfigValidator) validateSMTPConfig(cfg *config.Config) error {
if cfg.SMTP.Host == "" {
return fmt.Errorf("SMTP_HOST is required")
}
if cfg.SMTP.Port <= 0 || cfg.SMTP.Port > 65535 {
return fmt.Errorf("SMTP_PORT must be between 1 and 65535")
}
if cfg.SMTP.From == "" {
return fmt.Errorf("SMTP_FROM is required")
}
if !v.isValidEmail(cfg.SMTP.From) {
return fmt.Errorf("SMTP_FROM has invalid email format")
}
if cfg.App.AdminEmail == "" {
return fmt.Errorf("ADMIN_EMAIL is required")
}
if !v.isValidEmail(cfg.App.AdminEmail) {
return fmt.Errorf("ADMIN_EMAIL has invalid email format")
}
if !v.isValidHost(cfg.SMTP.Host) {
return fmt.Errorf("SMTP_HOST has invalid format")
}
return nil
}
func (v *ConfigValidator) validateServerConfig(cfg *config.Config) error {
serverPort, err := strconv.Atoi(cfg.Server.Port)
if err != nil {
return fmt.Errorf("SERVER_PORT must be a valid integer")
}
if serverPort <= 0 || serverPort > 65535 {
return fmt.Errorf("SERVER_PORT must be between 1 and 65535")
}
if cfg.App.BaseURL != "" {
if !v.isValidURL(cfg.App.BaseURL) {
return fmt.Errorf("BASE_URL has invalid format")
}
}
if cfg.Server.EnableTLS {
if cfg.Server.TLSCertFile == "" {
return fmt.Errorf("SERVER_TLS_CERT_FILE is required when TLS is enabled")
}
if cfg.Server.TLSKeyFile == "" {
return fmt.Errorf("SERVER_TLS_KEY_FILE is required when TLS is enabled")
}
}
return nil
}
func (v *ConfigValidator) validateSecurityConfig(cfg *config.Config) error {
if cfg.JWT.Secret == "" {
return fmt.Errorf("JWT_SECRET is required")
}
if len(cfg.JWT.Secret) < 32 {
return fmt.Errorf("JWT_SECRET must be at least 32 characters for security")
}
weakSecrets := []string{
"your-secret-key", "secret", "jwt-secret", "my-secret",
"change-me", "default-secret", "123456", "password",
"admin", "test", "development", "production", "staging",
}
lowerSecret := strings.ToLower(cfg.JWT.Secret)
for _, weak := range weakSecrets {
if lowerSecret == weak {
return fmt.Errorf("JWT_SECRET cannot be a common weak value: %s", weak)
}
}
return nil
}
func (v *ConfigValidator) validateFilePaths(cfg *config.Config) error {
if cfg.LogDir != "" {
if err := v.validateDirectory(cfg.LogDir, "LOG_DIR"); err != nil {
return err
}
}
if cfg.PIDDir != "" {
if err := v.validateDirectory(cfg.PIDDir, "PID_DIR"); err != nil {
return err
}
}
if cfg.Server.EnableTLS {
if err := v.validateFile(cfg.Server.TLSCertFile, "SERVER_TLS_CERT_FILE"); err != nil {
return err
}
if err := v.validateFile(cfg.Server.TLSKeyFile, "SERVER_TLS_KEY_FILE"); err != nil {
return err
}
}
return nil
}
func (v *ConfigValidator) validateDirectory(path, name string) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
if err := os.MkdirAll(path, 0755); err != nil {
return fmt.Errorf("%s directory does not exist and cannot be created: %v", name, err)
}
}
if info, err := os.Stat(path); err == nil {
if !info.IsDir() {
return fmt.Errorf("%s path exists but is not a directory", name)
}
}
if err := v.checkWritePermission(path); err != nil {
return fmt.Errorf("%s directory is not writable: %v", name, err)
}
return nil
}
func (v *ConfigValidator) validateFile(path, name string) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
return fmt.Errorf("%s file does not exist: %s", name, path)
}
if info, err := os.Stat(path); err == nil {
if info.IsDir() {
return fmt.Errorf("%s path exists but is a directory, not a file", name)
}
}
if err := v.checkReadPermission(path); err != nil {
return fmt.Errorf("%s file is not readable: %v", name, err)
}
return nil
}
func (v *ConfigValidator) checkWritePermission(path string) error {
testFile := filepath.Join(path, ".goyco_test_write")
file, err := os.Create(testFile)
if err != nil {
return err
}
_ = file.Close()
_ = os.Remove(testFile)
return nil
}
func (v *ConfigValidator) checkReadPermission(path string) error {
file, err := os.Open(path)
if err != nil {
return err
}
_ = file.Close()
return nil
}
func (v *ConfigValidator) isValidHost(host string) bool {
if net.ParseIP(host) != nil {
return true
}
if v.isValidHostname(host) {
return true
}
return false
}
func (v *ConfigValidator) isValidHostname(hostname string) bool {
if len(hostname) == 0 || len(hostname) > 253 {
return false
}
hostnameRegex := regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$`)
return hostnameRegex.MatchString(hostname)
}
func (v *ConfigValidator) isValidEmail(email string) bool {
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
return emailRegex.MatchString(email)
}
func (v *ConfigValidator) isValidURL(url string) bool {
urlRegex := regexp.MustCompile(`^https?://[a-zA-Z0-9.-]+(:[0-9]+)?(/.*)?$`)
return urlRegex.MatchString(url)
}
func (v *ConfigValidator) ValidateEnvironmentVariables() error {
requiredVars := []string{
"DB_HOST", "DB_PORT", "DB_NAME", "DB_USER", "DB_PASSWORD",
"SMTP_HOST", "SMTP_PORT", "SMTP_FROM", "ADMIN_EMAIL", "JWT_SECRET",
}
var missingVars []string
for _, varName := range requiredVars {
if os.Getenv(varName) == "" {
missingVars = append(missingVars, varName)
}
}
if len(missingVars) > 0 {
return fmt.Errorf("missing required environment variables: %s", strings.Join(missingVars, ", "))
}
return nil
}
func (v *ConfigValidator) ValidatePort(portStr, name string) (int, error) {
port, err := strconv.Atoi(portStr)
if err != nil {
return 0, fmt.Errorf("%s must be a valid integer", name)
}
if port <= 0 || port > 65535 {
return 0, fmt.Errorf("%s must be between 1 and 65535", name)
}
return port, nil
}
func (v *ConfigValidator) ValidateEmail(email, name string) error {
if email == "" {
return fmt.Errorf("%s is required", name)
}
if !v.isValidEmail(email) {
return fmt.Errorf("%s has invalid email format", name)
}
return nil
}
func (v *ConfigValidator) ValidatePassword(password, name string) error {
if password == "" {
return fmt.Errorf("%s is required", name)
}
if len(password) < 8 {
return fmt.Errorf("%s must be at least 8 characters", name)
}
if len(password) > 128 {
return fmt.Errorf("%s must be 128 characters or less", name)
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,346 @@
package commands
import (
"errors"
"fmt"
"log"
"os"
"os/signal"
"path/filepath"
"strconv"
"sync"
"syscall"
"time"
"goyco/internal/config"
)
func HandleStartCommand(cfg *config.Config, args []string) error {
fs := newFlagSet("start", printStartUsage)
if err := parseCommand(fs, args, "start"); err != nil {
if errors.Is(err, ErrHelpRequested) {
return nil
}
return err
}
if fs.NArg() > 0 {
printStartUsage()
return errors.New("unexpected arguments for start command")
}
return runDaemon(cfg)
}
func HandleStopCommand(cfg *config.Config, args []string) error {
fs := newFlagSet("stop", printStopUsage)
if err := parseCommand(fs, args, "stop"); err != nil {
if errors.Is(err, ErrHelpRequested) {
return nil
}
return err
}
if fs.NArg() > 0 {
printStopUsage()
return errors.New("unexpected arguments for stop command")
}
return stopDaemon(cfg)
}
func HandleStatusCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printStatusUsage)
if err := parseCommand(fs, args, name); err != nil {
if errors.Is(err, ErrHelpRequested) {
return nil
}
return err
}
if fs.NArg() > 0 {
printStatusUsage()
return errors.New("unexpected arguments for status command")
}
return runStatusCommand(cfg)
}
func printStartUsage() {
fmt.Fprintln(os.Stderr, "Usage: goyco start")
fmt.Fprintln(os.Stderr, "\nStart the web application in background.")
}
func printStopUsage() {
fmt.Fprintln(os.Stderr, "Usage: goyco stop")
fmt.Fprintln(os.Stderr, "\nStop the running daemon.")
}
func printStatusUsage() {
fmt.Fprintln(os.Stderr, "Usage: goyco status")
fmt.Fprintln(os.Stderr, "\nCheck if the daemon is running.")
}
func runStatusCommand(cfg *config.Config) error {
pidDir := cfg.PIDDir
pidFile := filepath.Join(pidDir, "goyco.pid")
if !isDaemonRunning(pidFile) {
fmt.Println("Goyco is not running")
return nil
}
data, err := os.ReadFile(pidFile)
if err != nil {
fmt.Printf("Goyco is running (PID file exists but cannot be read: %v)\n", err)
return nil
}
pid, err := strconv.Atoi(string(data))
if err != nil {
fmt.Printf("Goyco is running (PID file exists but contains invalid PID: %v)\n", err)
return nil
}
fmt.Printf("Goyco is running (PID %d)\n", pid)
return nil
}
func stopDaemon(cfg *config.Config) error {
pidDir := cfg.PIDDir
pidFile := filepath.Join(pidDir, "goyco.pid")
if !isDaemonRunning(pidFile) {
return fmt.Errorf("daemon is not running")
}
data, err := os.ReadFile(pidFile)
if err != nil {
return fmt.Errorf("read PID file: %w", err)
}
pid, err := strconv.Atoi(string(data))
if err != nil {
return fmt.Errorf("parse PID: %w", err)
}
process, err := os.FindProcess(pid)
if err != nil {
return fmt.Errorf("find process: %w", err)
}
if err := process.Signal(syscall.SIGTERM); err != nil {
return fmt.Errorf("send SIGTERM: %w", err)
}
time.Sleep(2 * time.Second)
if isDaemonRunning(pidFile) {
if err := process.Signal(syscall.SIGKILL); err != nil {
return fmt.Errorf("send SIGKILL: %w", err)
}
}
_ = os.Remove(pidFile)
fmt.Printf("Goyco stopped (PID %d)\n", pid)
return nil
}
func runDaemon(cfg *config.Config) error {
logDir := cfg.LogDir
if logDir == "" {
logDir = "/var/log"
}
if err := os.MkdirAll(logDir, 0o755); err != nil {
return fmt.Errorf("create log directory: %w", err)
}
pidDir := cfg.PIDDir
if pidDir == "" {
pidDir = "/run"
}
if err := os.MkdirAll(pidDir, 0o755); err != nil {
return fmt.Errorf("create PID directory: %w", err)
}
pidFile := filepath.Join(pidDir, "goyco.pid")
logFile := filepath.Join(logDir, "goyco.log")
if isDaemonRunning(pidFile) {
return fmt.Errorf("daemon is already running (PID file exists: %s)", pidFile)
}
daemonizeFnMu.Lock()
fn := daemonizeFn
daemonizeFnMu.Unlock()
pid, err := fn()
if err != nil {
return fmt.Errorf("failed to daemonize: %w", err)
}
if pid > 0 {
if err := writePIDFile(pidFile, pid); err != nil {
return fmt.Errorf("cannot write PID file: %w", err)
}
fmt.Printf("Goyco started with PID %d\n", pid)
fmt.Printf("PID file: %s\n", pidFile)
fmt.Printf("Log file: %s\n", logFile)
return nil
}
return runDaemonProcess(cfg, logDir, pidFile)
}
func daemonizeImpl() (int, error) {
args := make([]string, len(os.Args))
copy(args, os.Args)
args = append(args, "--daemon")
pid, err := syscall.ForkExec(os.Args[0], args, &syscall.ProcAttr{
Files: []uintptr{0, 1, 2},
Env: os.Environ(),
})
if err != nil {
return 0, err
}
return pid, nil
}
func isDaemonRunning(pidFile string) bool {
if _, err := os.Stat(pidFile); os.IsNotExist(err) {
return false
}
data, err := os.ReadFile(pidFile)
if err != nil {
return false
}
pid, err := strconv.Atoi(string(data))
if err != nil {
return false
}
process, err := os.FindProcess(pid)
if err != nil {
return false
}
err = process.Signal(syscall.Signal(0))
return err == nil
}
func writePIDFile(pidFile string, pid int) error {
return os.WriteFile(pidFile, []byte(strconv.Itoa(pid)), 0o644)
}
func runDaemonProcess(cfg *config.Config, logDir, pidFile string) error {
daemonizeFnMu.Lock()
setupLogFn := setupLoggingFn
daemonizeFnMu.Unlock()
if err := setupLogFn(cfg, logDir); err != nil {
return fmt.Errorf("setup daemon logging: %w", err)
}
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)
serverErr := make(chan error, 1)
go func() {
serverErr <- runServer(cfg, true)
}()
select {
case sig := <-sigChan:
log.Printf("Received signal %v, shutting down gracefully...", sig)
if err := os.Remove(pidFile); err != nil {
log.Printf("Error removing PID file: %v", err)
}
return nil
case err := <-serverErr:
if removeErr := os.Remove(pidFile); removeErr != nil {
log.Printf("Error removing PID file: %v", removeErr)
}
return err
}
}
func setupDaemonLoggingImpl(cfg *config.Config, logDir string) error {
logFile := filepath.Join(logDir, "goyco.log")
logFileHandle, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
return fmt.Errorf("open log file: %w", err)
}
log.SetOutput(logFileHandle)
log.SetFlags(log.LstdFlags)
log.Printf("Starting goyco in daemon mode")
return nil
}
func SetupDaemonLogging(cfg *config.Config, logDir string) error {
daemonizeFnMu.Lock()
setupLogFn := setupLoggingFn
daemonizeFnMu.Unlock()
return setupLogFn(cfg, logDir)
}
var runServer func(cfg *config.Config, daemon bool) error
func SetRunServer(fn func(cfg *config.Config, daemon bool) error) {
runServer = fn
}
type daemonizeFunc func() (int, error)
var (
daemonizeFnMu sync.Mutex
daemonizeFn daemonizeFunc = daemonizeImpl
setupLoggingFn func(cfg *config.Config, logDir string) error = setupDaemonLoggingImpl
)
func SetDaemonize(fn daemonizeFunc) {
daemonizeFnMu.Lock()
defer daemonizeFnMu.Unlock()
if fn == nil {
daemonizeFn = daemonizeImpl
} else {
daemonizeFn = fn
}
}
func SetSetupDaemonLogging(fn func(cfg *config.Config, logDir string) error) {
daemonizeFnMu.Lock()
defer daemonizeFnMu.Unlock()
if fn == nil {
setupLoggingFn = setupDaemonLoggingImpl
} else {
setupLoggingFn = fn
}
}
func RunDaemonProcessDirect(_ []string) error {
cfg, err := config.Load()
if err != nil {
return fmt.Errorf("load configuration: %w", err)
}
logDir := cfg.LogDir
if logDir == "" {
return fmt.Errorf("LOG_DIR environment variable is required for daemon mode")
}
pidDir := cfg.PIDDir
if err := os.MkdirAll(pidDir, 0o755); err != nil {
return fmt.Errorf("create PID directory: %w", err)
}
pidFile := filepath.Join(pidDir, "goyco.pid")
return runDaemonProcess(cfg, logDir, pidFile)
}

View File

@@ -0,0 +1,306 @@
package commands
import (
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"goyco/internal/config"
"goyco/internal/testutils"
)
func TestHandleStartCommand(t *testing.T) {
cfg := testutils.NewTestConfig()
t.Run("help requested", func(t *testing.T) {
err := HandleStartCommand(cfg, []string{"--help"})
if err != nil {
t.Errorf("unexpected error for help: %v", err)
}
})
t.Run("unexpected arguments", func(t *testing.T) {
err := HandleStartCommand(cfg, []string{"extra", "args"})
if err == nil {
t.Error("expected error for unexpected arguments")
}
expectedErr := "unexpected arguments for start command"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
}
func TestHandleStopCommand(t *testing.T) {
cfg := testutils.NewTestConfig()
t.Run("help requested", func(t *testing.T) {
err := HandleStopCommand(cfg, []string{"--help"})
if err != nil {
t.Errorf("unexpected error for help: %v", err)
}
})
t.Run("unexpected arguments", func(t *testing.T) {
err := HandleStopCommand(cfg, []string{"extra", "args"})
if err == nil {
t.Error("expected error for unexpected arguments")
}
expectedErr := "unexpected arguments for stop command"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
}
func TestHandleStatusCommand(t *testing.T) {
cfg := testutils.NewTestConfig()
t.Run("help requested", func(t *testing.T) {
err := HandleStatusCommand(cfg, "status", []string{"--help"})
if err != nil {
t.Errorf("unexpected error for help: %v", err)
}
})
t.Run("unexpected arguments", func(t *testing.T) {
err := HandleStatusCommand(cfg, "status", []string{"extra", "args"})
if err == nil {
t.Error("expected error for unexpected arguments")
}
expectedErr := "unexpected arguments for status command"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
}
func TestRunStatusCommand(t *testing.T) {
cfg := testutils.NewTestConfig()
t.Run("daemon not running", func(t *testing.T) {
tempDir := t.TempDir()
cfg.PIDDir = tempDir
err := runStatusCommand(cfg)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("daemon running with valid PID", func(t *testing.T) {
tempDir := t.TempDir()
cfg.PIDDir = tempDir
pidFile := filepath.Join(tempDir, "goyco.pid")
currentPID := os.Getpid()
err := os.WriteFile(pidFile, []byte(strconv.Itoa(currentPID)), 0644)
if err != nil {
t.Fatalf("Failed to create PID file: %v", err)
}
err = runStatusCommand(cfg)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("daemon running with invalid PID file", func(t *testing.T) {
tempDir := t.TempDir()
cfg.PIDDir = tempDir
pidFile := filepath.Join(tempDir, "goyco.pid")
err := os.WriteFile(pidFile, []byte("invalid-pid"), 0644)
if err != nil {
t.Fatalf("Failed to create PID file: %v", err)
}
err = runStatusCommand(cfg)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
func TestIsDaemonRunning(t *testing.T) {
t.Run("PID file does not exist", func(t *testing.T) {
pidFile := "/non/existent/pid/file"
result := isDaemonRunning(pidFile)
if result {
t.Error("expected false for non-existent PID file")
}
})
t.Run("PID file exists but contains invalid PID", func(t *testing.T) {
tempDir := t.TempDir()
pidFile := filepath.Join(tempDir, "goyco.pid")
err := os.WriteFile(pidFile, []byte("invalid-pid"), 0644)
if err != nil {
t.Fatalf("Failed to create PID file: %v", err)
}
result := isDaemonRunning(pidFile)
if result {
t.Error("expected false for invalid PID")
}
})
t.Run("PID file exists with valid PID", func(t *testing.T) {
tempDir := t.TempDir()
pidFile := filepath.Join(tempDir, "goyco.pid")
currentPID := os.Getpid()
err := os.WriteFile(pidFile, []byte(strconv.Itoa(currentPID)), 0644)
if err != nil {
t.Fatalf("Failed to create PID file: %v", err)
}
result := isDaemonRunning(pidFile)
if !result {
t.Error("expected true for valid PID")
}
})
}
func TestWritePIDFile(t *testing.T) {
t.Run("successful write", func(t *testing.T) {
tempDir := t.TempDir()
pidFile := filepath.Join(tempDir, "goyco.pid")
pid := 12345
err := writePIDFile(pidFile, pid)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
content, err := os.ReadFile(pidFile)
if err != nil {
t.Fatalf("Failed to read PID file: %v", err)
}
expectedContent := strconv.Itoa(pid)
if string(content) != expectedContent {
t.Errorf("expected PID file content %q, got %q", expectedContent, string(content))
}
})
t.Run("write to non-existent directory", func(t *testing.T) {
pidFile := "/non/existent/directory/goyco.pid"
pid := 12345
err := writePIDFile(pidFile, pid)
if err == nil {
t.Error("expected error for non-existent directory")
}
})
}
func TestSetupDaemonLogging(t *testing.T) {
cfg := testutils.NewTestConfig()
tempDir := t.TempDir()
t.Run("successful setup", func(t *testing.T) {
err := SetupDaemonLogging(cfg, tempDir)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
logFile := filepath.Join(tempDir, "goyco.log")
if _, err := os.Stat(logFile); os.IsNotExist(err) {
t.Error("expected log file to be created")
}
})
t.Run("setup with non-existent directory", func(t *testing.T) {
nonExistentDir := "/non/existent/directory"
err := SetupDaemonLogging(cfg, nonExistentDir)
if err == nil {
t.Error("expected error for non-existent directory")
}
})
}
func TestRunDaemonProcessDirect(t *testing.T) {
SetRunServer(func(_ *config.Config, _ bool) error {
return nil
})
defer SetRunServer(nil)
SetDaemonize(func() (int, error) {
return 999, nil
})
defer SetDaemonize(nil)
SetSetupDaemonLogging(func(_ *config.Config, _ string) error {
return nil
})
defer SetSetupDaemonLogging(nil)
t.Run("missing DB_PASSWORD", func(t *testing.T) {
t.Setenv("DB_PASSWORD", "")
t.Setenv("SMTP_HOST", "")
t.Setenv("SMTP_FROM", "")
t.Setenv("ADMIN_EMAIL", "")
t.Setenv("LOG_DIR", "/tmp/test-logs")
err := RunDaemonProcessDirect([]string{})
if err == nil {
t.Error("expected error for missing DB_PASSWORD")
}
expectedErr := "load configuration: DB_PASSWORD is required"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("empty LOG_DIR returns error", func(t *testing.T) {
t.Setenv("DB_PASSWORD", "test-password")
t.Setenv("SMTP_HOST", "smtp.example.com")
t.Setenv("SMTP_FROM", "test@example.com")
t.Setenv("ADMIN_EMAIL", "admin@example.com")
t.Setenv("JWT_SECRET", "this-is-a-very-secure-jwt-secret-key-that-is-long-enough")
t.Setenv("LOG_DIR", "")
err := RunDaemonProcessDirect([]string{})
if err == nil {
t.Skip("LOG_DIR empty doesn't return error (may be handled by config defaults)")
return
}
errMsg := err.Error()
if !strings.Contains(errMsg, "LOG_DIR environment variable is required") &&
!strings.Contains(errMsg, "permission denied") &&
!strings.Contains(errMsg, "setup daemon logging") {
t.Logf("Got error (may be acceptable): %q", errMsg)
}
})
}

View File

@@ -0,0 +1,44 @@
package commands
import (
"errors"
"fmt"
"os"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
)
func HandleMigrateCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printMigrateUsage)
if err := parseCommand(fs, args, name); err != nil {
if errors.Is(err, ErrHelpRequested) {
return nil
}
return err
}
if fs.NArg() > 0 {
printMigrateUsage()
return errors.New("unexpected arguments for migrate command")
}
return withDatabase(cfg, func(db *gorm.DB) error {
return runMigrateCommand(db)
})
}
func runMigrateCommand(db *gorm.DB) error {
fmt.Println("Running database migrations...")
if err := database.Migrate(db); err != nil {
return fmt.Errorf("run migrations: %w", err)
}
fmt.Println("Migrations applied successfully")
return nil
}
func printMigrateUsage() {
fmt.Fprintln(os.Stderr, "Usage: goyco migrate")
fmt.Fprintln(os.Stderr, "\nApply database migrations.")
}

View File

@@ -0,0 +1,42 @@
package commands
import (
"testing"
"goyco/internal/testutils"
)
func TestHandleMigrateCommand(t *testing.T) {
cfg := testutils.NewTestConfig()
t.Run("help requested", func(t *testing.T) {
err := HandleMigrateCommand(cfg, "migrate", []string{"--help"})
if err != nil {
t.Errorf("unexpected error for help: %v", err)
}
})
t.Run("unexpected arguments", func(t *testing.T) {
err := HandleMigrateCommand(cfg, "migrate", []string{"extra", "args"})
if err == nil {
t.Error("expected error for unexpected arguments")
}
if err.Error() != "unexpected arguments for migrate command" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("runs migrations", func(t *testing.T) {
cfg := testutils.NewTestConfig()
setInMemoryDBConnector(t)
err := HandleMigrateCommand(cfg, "migrate", []string{})
if err != nil {
t.Fatalf("unexpected error running migrations: %v", err)
}
})
}

View File

@@ -0,0 +1,434 @@
package commands
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"runtime"
"sync"
"time"
"golang.org/x/crypto/bcrypt"
"goyco/internal/database"
"goyco/internal/repositories"
)
type ParallelProcessor struct {
maxWorkers int
timeout time.Duration
}
func NewParallelProcessor() *ParallelProcessor {
maxWorkers := max(min(runtime.NumCPU(), 8), 2)
return &ParallelProcessor{
maxWorkers: maxWorkers,
timeout: 30 * time.Second,
}
}
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan userResult, count)
errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers)
var wg sync.WaitGroup
for i := range count {
wg.Add(1)
go func(index int) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
user, err := p.createSingleUser(userRepo, index+1)
if err != nil {
errors <- fmt.Errorf("create user %d: %w", index+1, err)
return
}
results <- userResult{user: user, index: index}
}(i)
}
go func() {
wg.Wait()
close(results)
close(errors)
}()
users := make([]database.User, count)
completed := 0
for {
select {
case result, ok := <-results:
if !ok {
return users, nil
}
users[result.index] = result.user
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-errors:
if err != nil {
return nil, err
}
case <-ctx.Done():
return nil, fmt.Errorf("timeout creating users: %w", ctx.Err())
}
}
}
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan postResult, count)
errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers)
var wg sync.WaitGroup
for i := range count {
wg.Add(1)
go func(index int) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
post, err := p.createSinglePost(postRepo, authorID, index+1)
if err != nil {
errors <- fmt.Errorf("create post %d: %w", index+1, err)
return
}
results <- postResult{post: post, index: index}
}(i)
}
go func() {
wg.Wait()
close(results)
close(errors)
}()
posts := make([]database.Post, count)
completed := 0
for {
select {
case result, ok := <-results:
if !ok {
return posts, nil
}
posts[result.index] = result.post
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-errors:
if err != nil {
return nil, err
}
case <-ctx.Done():
return nil, fmt.Errorf("timeout creating posts: %w", ctx.Err())
}
}
}
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan voteResult, len(posts))
errors := make(chan error, len(posts))
semaphore := make(chan struct{}, p.maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
wg.Add(1)
go func(index int, post database.Post) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
votes, err := p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
if err != nil {
errors <- fmt.Errorf("create votes for post %d: %w", post.ID, err)
return
}
results <- voteResult{votes: votes, index: index}
}(i, post)
}
go func() {
wg.Wait()
close(results)
close(errors)
}()
totalVotes := 0
completed := 0
for {
select {
case result, ok := <-results:
if !ok {
return totalVotes, nil
}
totalVotes += result.votes
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-errors:
if err != nil {
return 0, err
}
case <-ctx.Done():
return 0, fmt.Errorf("timeout creating votes: %w", ctx.Err())
}
}
}
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
errors := make(chan error, len(posts))
semaphore := make(chan struct{}, p.maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
wg.Add(1)
go func(index int, post database.Post) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
err := p.updateSinglePostScore(postRepo, voteRepo, post)
if err != nil {
errors <- fmt.Errorf("update post %d scores: %w", post.ID, err)
return
}
if progress != nil {
progress.Update(index + 1)
}
}(i, post)
}
go func() {
wg.Wait()
close(errors)
}()
for err := range errors {
if err != nil {
return err
}
}
return nil
}
type userResult struct {
user database.User
index int
}
type postResult struct {
post database.Post
index int
}
type voteResult struct {
votes int
index int
}
func (p *ParallelProcessor) createSingleUser(userRepo repositories.UserRepository, index int) (database.User, error) {
username := fmt.Sprintf("user_%d", index)
email := fmt.Sprintf("user_%d@goyco.local", index)
password := "password123"
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return database.User{}, fmt.Errorf("hash password: %w", err)
}
user := &database.User{
Username: username,
Email: email,
Password: string(hashedPassword),
EmailVerified: true,
}
if err := userRepo.Create(user); err != nil {
return database.User{}, fmt.Errorf("create user: %w", err)
}
return *user, nil
}
func (p *ParallelProcessor) createSinglePost(postRepo repositories.PostRepository, authorID uint, index int) (database.Post, error) {
sampleTitles := []string{
"Amazing JavaScript Framework",
"Python Best Practices",
"Go Performance Tips",
"Database Optimization",
"Web Security Guide",
"Machine Learning Basics",
"Cloud Architecture",
"DevOps Automation",
"API Design Patterns",
"Frontend Optimization",
"Backend Scaling",
"Container Orchestration",
"Microservices Architecture",
"Testing Strategies",
"Code Review Process",
"Version Control Best Practices",
"Continuous Integration",
"Monitoring and Alerting",
"Error Handling Patterns",
"Data Structures Explained",
}
sampleDomains := []string{
"example.com",
"techblog.org",
"devguide.net",
"programming.io",
"codeexamples.com",
"tutorialhub.org",
"bestpractices.dev",
"learnprogramming.net",
"codingtips.org",
"softwareengineering.com",
}
title := sampleTitles[index%len(sampleTitles)]
if index >= len(sampleTitles) {
title = fmt.Sprintf("%s - Part %d", title, (index/len(sampleTitles))+1)
}
domain := sampleDomains[index%len(sampleDomains)]
path := generateRandomPath()
url := fmt.Sprintf("https://%s%s", domain, path)
content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", index, title)
post := &database.Post{
Title: title,
URL: url,
Content: content,
AuthorID: &authorID,
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
if err := postRepo.Create(post); err != nil {
return database.Post{}, fmt.Errorf("create post: %w", err)
}
return *post, nil
}
func (p *ParallelProcessor) createVotesForPost(voteRepo repositories.VoteRepository, users []database.User, post database.Post, avgVotesPerPost int) (int, error) {
voteCount, _ := rand.Int(rand.Reader, big.NewInt(int64(avgVotesPerPost*2)+1))
numVotes := int(voteCount.Int64())
if numVotes == 0 && avgVotesPerPost > 0 {
chance, _ := rand.Int(rand.Reader, big.NewInt(5))
if chance.Int64() > 0 {
numVotes = 1
}
}
totalVotes := 0
usedUsers := make(map[uint]bool)
for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
userIdx, _ := rand.Int(rand.Reader, big.NewInt(int64(len(users))))
user := users[userIdx.Int64()]
if usedUsers[user.ID] {
continue
}
usedUsers[user.ID] = true
voteTypeInt, _ := rand.Int(rand.Reader, big.NewInt(10))
var voteType database.VoteType
if voteTypeInt.Int64() < 7 {
voteType = database.VoteUp
} else {
voteType = database.VoteDown
}
vote := &database.Vote{
UserID: &user.ID,
PostID: post.ID,
Type: voteType,
}
if err := voteRepo.Create(vote); err != nil {
return totalVotes, fmt.Errorf("create vote: %w", err)
}
totalVotes++
}
return totalVotes, nil
}
func (p *ParallelProcessor) updateSinglePostScore(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, post database.Post) error {
upVotes, downVotes, err := getVoteCounts(voteRepo, post.ID)
if err != nil {
return fmt.Errorf("get vote counts: %w", err)
}
post.UpVotes = upVotes
post.DownVotes = downVotes
post.Score = upVotes - downVotes
if err := postRepo.Update(&post); err != nil {
return fmt.Errorf("update post: %w", err)
}
return nil
}

View File

@@ -0,0 +1,130 @@
package commands_test
import (
"errors"
"fmt"
"sync"
"testing"
"golang.org/x/crypto/bcrypt"
"goyco/cmd/goyco/commands"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/testutils"
)
func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
const successCount = 4
tests := []struct {
name string
count int
repoFactory func() repositories.UserRepository
progress *commands.ProgressIndicator
validate func(t *testing.T, got []database.User)
wantErr bool
}{
{
name: "creates users with deterministic fields",
count: successCount,
repoFactory: func() repositories.UserRepository {
base := testutils.NewMockUserRepository()
return newFakeUserRepo(base, 0, nil)
},
progress: nil,
validate: func(t *testing.T, got []database.User) {
t.Helper()
if len(got) != successCount {
t.Fatalf("expected %d users, got %d", successCount, len(got))
}
for i, user := range got {
expectedUsername := fmt.Sprintf("user_%d", i+1)
expectedEmail := fmt.Sprintf("user_%d@goyco.local", i+1)
if user.Username != expectedUsername {
t.Errorf("user %d username mismatch: got %q want %q", i, user.Username, expectedUsername)
}
if user.Email != expectedEmail {
t.Errorf("user %d email mismatch: got %q want %q", i, user.Email, expectedEmail)
}
if !user.EmailVerified {
t.Errorf("user %d expected EmailVerified to be true", i)
}
if user.ID == 0 {
t.Errorf("user %d expected non-zero ID", i)
}
if user.Password == "" {
t.Errorf("user %d expected hashed password to be populated", i)
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte("password123")); err != nil {
t.Errorf("user %d password not hashed correctly: %v", i, err)
}
if user.CreatedAt.IsZero() {
t.Errorf("user %d expected CreatedAt to be set", i)
}
if user.UpdatedAt.IsZero() {
t.Errorf("user %d expected UpdatedAt to be set", i)
}
}
},
},
{
name: "returns error when repository create fails",
count: 3,
repoFactory: func() repositories.UserRepository {
base := testutils.NewMockUserRepository()
return newFakeUserRepo(base, 1, errors.New("create failure"))
},
progress: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
repo := tt.repoFactory()
p := commands.NewParallelProcessor()
got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress)
if gotErr != nil {
if !tt.wantErr {
t.Errorf("CreateUsersInParallel() failed: %v", gotErr)
}
if got != nil {
t.Error("expected nil result when error occurs")
}
return
}
if tt.wantErr {
t.Fatal("CreateUsersInParallel() succeeded unexpectedly")
}
if tt.validate != nil {
tt.validate(t, got)
}
})
}
}
type fakeUserRepo struct {
repositories.UserRepository
mu sync.Mutex
failAt int
err error
calls int
}
func newFakeUserRepo(base repositories.UserRepository, failAt int, err error) *fakeUserRepo {
return &fakeUserRepo{
UserRepository: base,
failAt: failAt,
err: err,
}
}
func (r *fakeUserRepo) Create(user *database.User) error {
r.mu.Lock()
defer r.mu.Unlock()
r.calls++
if r.failAt > 0 && r.calls >= r.failAt {
return r.err
}
return r.UserRepository.Create(user)
}

254
cmd/goyco/commands/post.go Normal file
View File

@@ -0,0 +1,254 @@
package commands
import (
"errors"
"flag"
"fmt"
"os"
"strconv"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/security"
"goyco/internal/services"
"gorm.io/gorm"
)
func HandlePostCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printPostUsage)
if err := parseCommand(fs, args, name); err != nil {
if errors.Is(err, ErrHelpRequested) {
return nil
}
return err
}
return withDatabase(cfg, func(db *gorm.DB) error {
repo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
voteService := services.NewVoteService(voteRepo, repo, db)
postQueries := services.NewPostQueries(repo, voteService)
return runPostCommand(postQueries, repo, fs.Args())
})
}
func runPostCommand(postQueries *services.PostQueries, repo repositories.PostRepository, args []string) error {
if len(args) == 0 {
printPostUsage()
return errors.New("missing post subcommand")
}
switch args[0] {
case "delete":
return postDelete(repo, args[1:])
case "list":
return postList(postQueries, args[1:])
case "search":
return postSearch(postQueries, args[1:])
case "help", "-h", "--help":
printPostUsage()
return nil
default:
printPostUsage()
return fmt.Errorf("unknown post subcommand: %s", args[0])
}
}
func printPostUsage() {
fmt.Fprintln(os.Stderr, "Post subcommands:")
fmt.Fprintln(os.Stderr, " delete <id>")
fmt.Fprintln(os.Stderr, " list [--limit <n>] [--offset <n>] [--user-id <id>]")
fmt.Fprintln(os.Stderr, " search <term> [--limit <n>] [--offset <n>]")
}
func postDelete(repo repositories.PostRepository, args []string) error {
fs := flag.NewFlagSet("post delete", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
return err
}
if fs.NArg() == 0 {
fs.Usage()
return errors.New("post ID is required")
}
idStr := fs.Arg(0)
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
return fmt.Errorf("invalid post ID: %s", idStr)
}
if id == 0 {
return errors.New("post ID must be greater than 0")
}
if err := repo.Delete(uint(id)); err != nil {
return fmt.Errorf("delete post: %w", err)
}
fmt.Printf("Post deleted: ID=%d\n", id)
return nil
}
func postList(postQueries *services.PostQueries, args []string) error {
fs := flag.NewFlagSet("post list", flag.ContinueOnError)
limit := fs.Int("limit", 0, "max number of posts to list")
offset := fs.Int("offset", 0, "number of posts to skip")
userID := fs.Uint("user-id", 0, "filter posts by author id")
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
return err
}
opts := services.QueryOptions{
Limit: *limit,
Offset: *offset,
}
ctx := services.VoteContext{}
var (
posts []database.Post
err error
)
if *userID > 0 {
posts, err = postQueries.GetByUserID(*userID, opts, ctx)
} else {
posts, err = postQueries.GetAll(opts, ctx)
}
if err != nil {
return fmt.Errorf("list posts: %w", err)
}
if len(posts) == 0 {
fmt.Println("No posts found")
return nil
}
maxIDWidth := 2
maxTitleWidth := 5
maxAuthorIDWidth := 8
maxScoreWidth := 5
maxCreatedAtWidth := 10
for _, p := range posts {
authorID := uint(0)
if p.AuthorID != nil {
authorID = *p.AuthorID
}
if p.Author.ID != 0 {
authorID = p.Author.ID
}
truncatedTitle := truncate(p.Title, 40)
createdAtStr := p.CreatedAt.Format("2006-01-02 15:04:05")
if len(fmt.Sprintf("%d", p.ID)) > maxIDWidth {
maxIDWidth = len(fmt.Sprintf("%d", p.ID))
}
if len(truncatedTitle) > maxTitleWidth {
maxTitleWidth = len(truncatedTitle)
}
if len(fmt.Sprintf("%d", authorID)) > maxAuthorIDWidth {
maxAuthorIDWidth = len(fmt.Sprintf("%d", authorID))
}
if len(fmt.Sprintf("%d", p.Score)) > maxScoreWidth {
maxScoreWidth = len(fmt.Sprintf("%d", p.Score))
}
if len(createdAtStr) > maxCreatedAtWidth {
maxCreatedAtWidth = len(createdAtStr)
}
}
fmt.Printf("%-*s %-*s %-*s %-*s %s\n",
maxIDWidth, "ID",
maxTitleWidth, "Title",
maxAuthorIDWidth, "AuthorID",
maxScoreWidth, "Score",
"CreatedAt")
for _, p := range posts {
authorID := uint(0)
if p.AuthorID != nil {
authorID = *p.AuthorID
}
if p.Author.ID != 0 {
authorID = p.Author.ID
}
truncatedTitle := truncate(p.Title, 40)
createdAtStr := p.CreatedAt.Format("2006-01-02 15:04:05")
fmt.Printf("%-*d %-*s %-*d %-*d %s\n",
maxIDWidth, p.ID,
maxTitleWidth, truncatedTitle,
maxAuthorIDWidth, authorID,
maxScoreWidth, p.Score,
createdAtStr)
}
return nil
}
func postSearch(postQueries *services.PostQueries, args []string) error {
fs := flag.NewFlagSet("post search", flag.ContinueOnError)
limit := fs.Int("limit", 10, "max number of posts to return")
offset := fs.Int("offset", 0, "number of posts to skip")
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
return err
}
if fs.NArg() == 0 {
fs.Usage()
return errors.New("search term is required")
}
if *limit < 0 {
return errors.New("limit must be non-negative")
}
if *offset < 0 {
return errors.New("offset must be non-negative")
}
sanitizer := security.NewInputSanitizer()
term := fs.Arg(0)
sanitizedTerm, err := sanitizer.SanitizeSearchTerm(term)
if err != nil {
return fmt.Errorf("search term validation: %w", err)
}
opts := services.QueryOptions{
Limit: *limit,
Offset: *offset,
}
ctx := services.VoteContext{}
posts, err := postQueries.GetSearch(sanitizedTerm, opts, ctx)
if err != nil {
return fmt.Errorf("search posts: %w", err)
}
if len(posts) == 0 {
fmt.Println("No posts found matching your search")
return nil
}
fmt.Printf("%-4s %-40s %-12s %-6s %-19s\n", "ID", "Title", "AuthorID", "Score", "CreatedAt")
for _, p := range posts {
authorID := uint(0)
if p.AuthorID != nil {
authorID = *p.AuthorID
}
if p.Author.ID != 0 {
authorID = p.Author.ID
}
fmt.Printf("%-4d %-40s %-12d %-6d %-19s\n", p.ID, truncate(p.Title, 40), authorID, p.Score, p.CreatedAt.Format("2006-01-02 15:04:05"))
}
return nil
}

View File

@@ -0,0 +1,567 @@
package commands
import (
"errors"
"strings"
"testing"
"time"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/services"
"goyco/internal/testutils"
"gorm.io/gorm"
)
func createPostQueries(repo repositories.PostRepository) *services.PostQueries {
voteRepo := testutils.NewMockVoteRepository()
voteService := services.NewVoteService(voteRepo, repo, nil)
return services.NewPostQueries(repo, voteService)
}
func TestHandlePostCommand(t *testing.T) {
cfg := testutils.NewTestConfig()
t.Run("help requested", func(t *testing.T) {
err := HandlePostCommand(cfg, "post", []string{"--help"})
if err != nil {
t.Errorf("unexpected error for help: %v", err)
}
})
}
func TestRunPostCommand(t *testing.T) {
mockRepo := testutils.NewMockPostRepository()
postQueries := createPostQueries(mockRepo)
t.Run("missing subcommand", func(t *testing.T) {
err := runPostCommand(postQueries, mockRepo, []string{})
if err == nil {
t.Error("expected error for missing subcommand")
}
if err.Error() != "missing post subcommand" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("unknown subcommand", func(t *testing.T) {
err := runPostCommand(postQueries, mockRepo, []string{"unknown"})
if err == nil {
t.Error("expected error for unknown subcommand")
}
expectedErr := "unknown post subcommand: unknown"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("help subcommand", func(t *testing.T) {
err := runPostCommand(postQueries, mockRepo, []string{"help"})
if err != nil {
t.Errorf("unexpected error for help: %v", err)
}
})
}
func TestPostDelete(t *testing.T) {
mockRepo := testutils.NewMockPostRepository()
testPost := &database.Post{
Title: "Test Post",
Content: "Test Content",
AuthorID: &[]uint{1}[0],
Score: 0,
}
_ = mockRepo.Create(testPost)
t.Run("successful delete", func(t *testing.T) {
err := postDelete(mockRepo, []string{"1"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("missing id", func(t *testing.T) {
err := postDelete(mockRepo, []string{})
if err == nil {
t.Error("expected error for missing id")
}
if err.Error() != "post ID is required" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("invalid id", func(t *testing.T) {
err := postDelete(mockRepo, []string{"0"})
if err == nil {
t.Error("expected error for invalid id")
}
if err.Error() != "post ID must be greater than 0" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("non-existent post", func(t *testing.T) {
err := postDelete(mockRepo, []string{"999"})
if err == nil {
t.Error("expected error for non-existent post")
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
t.Errorf("expected record not found error, got: %v", err)
}
})
t.Run("repository error", func(t *testing.T) {
mockRepo.DeleteErr = errors.New("database error")
err := postDelete(mockRepo, []string{"1"})
if err == nil {
t.Error("expected error from repository")
}
expectedErr := "delete post: database error"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
}
func TestPostList(t *testing.T) {
mockRepo := testutils.NewMockPostRepository()
testPosts := []*database.Post{
{
Title: "First Post",
Content: "First Content",
AuthorID: &[]uint{1}[0],
Score: 10,
CreatedAt: time.Now().Add(-2 * time.Hour),
},
{
Title: "Second Post",
Content: "Second Content",
AuthorID: &[]uint{2}[0],
Score: 5,
CreatedAt: time.Now().Add(-1 * time.Hour),
},
}
for _, post := range testPosts {
_ = mockRepo.Create(post)
}
postQueries := createPostQueries(mockRepo)
t.Run("list all posts", func(t *testing.T) {
err := postList(postQueries, []string{})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("list with limit", func(t *testing.T) {
err := postList(postQueries, []string{"--limit", "1"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("list with offset", func(t *testing.T) {
err := postList(postQueries, []string{"--offset", "1"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("list with user filter", func(t *testing.T) {
err := postList(postQueries, []string{"--user-id", "1"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("list with all filters", func(t *testing.T) {
err := postList(postQueries, []string{"--limit", "1", "--offset", "0", "--user-id", "1"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("empty result", func(t *testing.T) {
emptyRepo := testutils.NewMockPostRepository()
emptyPostQueries := createPostQueries(emptyRepo)
err := postList(emptyPostQueries, []string{})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("repository error", func(t *testing.T) {
mockRepo.GetErr = errors.New("database error")
err := postList(postQueries, []string{})
if err == nil {
t.Error("expected error from repository")
}
expectedErr := "list posts: database error"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
}
func TestPostSearch(t *testing.T) {
mockRepo := testutils.NewMockPostRepository()
postQueries := createPostQueries(mockRepo)
testPosts := []*database.Post{
{
Title: "Golang Tutorial",
Content: "Learn Go programming language",
AuthorID: &[]uint{1}[0],
Score: 10,
CreatedAt: time.Now().Add(-2 * time.Hour),
},
{
Title: "Python Guide",
Content: "Learn Python programming",
AuthorID: &[]uint{2}[0],
Score: 5,
CreatedAt: time.Now().Add(-1 * time.Hour),
},
{
Title: "Go Best Practices",
Content: "Advanced Go techniques and patterns",
AuthorID: &[]uint{1}[0],
Score: 15,
CreatedAt: time.Now().Add(-30 * time.Minute),
},
}
for _, post := range testPosts {
_ = mockRepo.Create(post)
}
t.Run("search with results", func(t *testing.T) {
err := postSearch(postQueries, []string{"Go"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("case insensitive search", func(t *testing.T) {
mockRepo.SearchCalls = nil
err := postSearch(postQueries, []string{"golang"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(mockRepo.SearchCalls) != 1 {
t.Errorf("expected 1 search call, got %d", len(mockRepo.SearchCalls))
} else {
call := mockRepo.SearchCalls[0]
if call.Query != "golang" {
t.Errorf("expected query 'golang', got %q", call.Query)
}
}
})
t.Run("search with no results", func(t *testing.T) {
err := postSearch(postQueries, []string{"nonexistent"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("search with limit", func(t *testing.T) {
mockRepo.SearchCalls = nil
err := postSearch(postQueries, []string{"--limit", "1", "Go"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(mockRepo.SearchCalls) != 1 {
t.Errorf("expected 1 search call, got %d", len(mockRepo.SearchCalls))
} else {
call := mockRepo.SearchCalls[0]
if call.Query != "Go" {
t.Errorf("expected query 'Go', got %q", call.Query)
}
if call.Limit != 1 {
t.Errorf("expected limit 1, got %d", call.Limit)
}
if call.Offset != 0 {
t.Errorf("expected offset 0, got %d", call.Offset)
}
}
})
t.Run("search with offset", func(t *testing.T) {
mockRepo.SearchCalls = nil
err := postSearch(postQueries, []string{"--offset", "1", "Go"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(mockRepo.SearchCalls) != 1 {
t.Errorf("expected 1 search call, got %d", len(mockRepo.SearchCalls))
} else {
call := mockRepo.SearchCalls[0]
if call.Query != "Go" {
t.Errorf("expected query 'Go', got %q", call.Query)
}
if call.Limit != 10 {
t.Errorf("expected limit 10, got %d", call.Limit)
}
if call.Offset != 1 {
t.Errorf("expected offset 1, got %d", call.Offset)
}
}
})
t.Run("search with limit and offset", func(t *testing.T) {
mockRepo.SearchCalls = nil
err := postSearch(postQueries, []string{"--limit", "1", "--offset", "1", "Go"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(mockRepo.SearchCalls) != 1 {
t.Errorf("expected 1 search call, got %d", len(mockRepo.SearchCalls))
} else {
call := mockRepo.SearchCalls[0]
if call.Query != "Go" {
t.Errorf("expected query 'Go', got %q", call.Query)
}
if call.Limit != 1 {
t.Errorf("expected limit 1, got %d", call.Limit)
}
if call.Offset != 1 {
t.Errorf("expected offset 1, got %d", call.Offset)
}
}
})
t.Run("missing search term", func(t *testing.T) {
err := postSearch(postQueries, []string{})
if err == nil {
t.Error("expected error for missing search term")
}
expectedErr := "search term is required"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("invalid limit flag", func(t *testing.T) {
err := postSearch(postQueries, []string{"--limit", "invalid", "Go"})
if err == nil {
t.Error("expected error for invalid limit")
}
})
t.Run("invalid offset flag", func(t *testing.T) {
err := postSearch(postQueries, []string{"--offset", "invalid", "Go"})
if err == nil {
t.Error("expected error for invalid offset")
}
})
t.Run("negative limit", func(t *testing.T) {
err := postSearch(postQueries, []string{"--limit", "-1", "Go"})
if err == nil {
t.Error("expected error for negative limit")
}
expectedErr := "limit must be non-negative"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("negative offset", func(t *testing.T) {
err := postSearch(postQueries, []string{"--offset", "-1", "Go"})
if err == nil {
t.Error("expected error for negative offset")
}
expectedErr := "offset must be non-negative"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("repository error", func(t *testing.T) {
mockRepo.SearchErr = errors.New("database error")
err := postSearch(postQueries, []string{"Go"})
if err == nil {
t.Error("expected error from repository")
}
expectedErr := "search posts: database error"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("unknown flag", func(t *testing.T) {
err := postSearch(postQueries, []string{"--unknown-flag", "Go"})
if err == nil {
t.Error("expected error for unknown flag")
}
})
t.Run("missing limit value", func(t *testing.T) {
err := postSearch(postQueries, []string{"--limit"})
if err == nil {
t.Error("expected error for missing limit value")
}
})
t.Run("missing offset value", func(t *testing.T) {
err := postSearch(postQueries, []string{"--offset"})
if err == nil {
t.Error("expected error for missing offset value")
}
})
}
func TestPostListFlagParsing(t *testing.T) {
mockRepo := testutils.NewMockPostRepository()
postQueries := createPostQueries(mockRepo)
testPosts := []*database.Post{
{
Title: "First Post",
Content: "First Content",
AuthorID: &[]uint{1}[0],
Score: 10,
CreatedAt: time.Now().Add(-2 * time.Hour),
},
}
for _, post := range testPosts {
_ = mockRepo.Create(post)
}
t.Run("invalid limit type", func(t *testing.T) {
err := postList(postQueries, []string{"--limit", "abc"})
if err == nil {
t.Error("expected error for invalid limit type")
}
})
t.Run("invalid offset type", func(t *testing.T) {
err := postList(postQueries, []string{"--offset", "xyz"})
if err == nil {
t.Error("expected error for invalid offset type")
}
})
t.Run("invalid user-id type", func(t *testing.T) {
err := postList(postQueries, []string{"--user-id", "invalid"})
if err == nil {
t.Error("expected error for invalid user-id type")
}
})
t.Run("unknown flag", func(t *testing.T) {
err := postList(postQueries, []string{"--unknown-flag"})
if err == nil {
t.Error("expected error for unknown flag")
}
})
t.Run("missing limit value", func(t *testing.T) {
err := postList(postQueries, []string{"--limit"})
if err == nil {
t.Error("expected error for missing limit value")
}
})
t.Run("missing offset value", func(t *testing.T) {
err := postList(postQueries, []string{"--offset"})
if err == nil {
t.Error("expected error for missing offset value")
}
})
t.Run("missing user-id value", func(t *testing.T) {
err := postList(postQueries, []string{"--user-id"})
if err == nil {
t.Error("expected error for missing user-id value")
}
})
}
func TestPostDeleteFlagParsing(t *testing.T) {
mockRepo := testutils.NewMockPostRepository()
t.Run("invalid id type", func(t *testing.T) {
err := postDelete(mockRepo, []string{"abc"})
if err == nil {
t.Error("expected error for invalid id type")
}
if !strings.Contains(err.Error(), "invalid post ID") {
t.Errorf("expected invalid post ID error, got: %v", err)
}
})
t.Run("non-numeric id", func(t *testing.T) {
err := postDelete(mockRepo, []string{"not-a-number"})
if err == nil {
t.Error("expected error for non-numeric id")
}
})
}

View File

@@ -0,0 +1,321 @@
package commands
import (
"fmt"
"os"
"strings"
"sync"
"time"
)
type clock interface {
Now() time.Time
}
type realClock struct{}
func (c *realClock) Now() time.Time {
return time.Now()
}
type ProgressIndicator struct {
total int
current int
startTime time.Time
lastUpdate time.Time
description string
showETA bool
mu sync.Mutex
clock clock
}
func NewProgressIndicator(total int, description string) *ProgressIndicator {
return &ProgressIndicator{
total: total,
current: 0,
startTime: time.Now(),
lastUpdate: time.Now(),
description: description,
showETA: true,
clock: &realClock{},
}
}
func newProgressIndicatorWithClock(total int, description string, c clock) *ProgressIndicator {
now := c.Now()
return &ProgressIndicator{
total: total,
current: 0,
startTime: now,
lastUpdate: now,
description: description,
showETA: true,
clock: c,
}
}
func (p *ProgressIndicator) Update(current int) {
p.mu.Lock()
defer p.mu.Unlock()
p.current = current
now := p.clock.Now()
if now.Sub(p.lastUpdate) < 100*time.Millisecond {
return
}
p.lastUpdate = now
p.display()
}
func (p *ProgressIndicator) Increment() {
p.mu.Lock()
p.current++
current := p.current
now := p.clock.Now()
shouldUpdate := now.Sub(p.lastUpdate) >= 100*time.Millisecond
if shouldUpdate {
p.lastUpdate = now
}
p.mu.Unlock()
if shouldUpdate {
p.displayWithValue(current)
}
}
func (p *ProgressIndicator) SetDescription(description string) {
p.mu.Lock()
defer p.mu.Unlock()
p.description = description
}
func (p *ProgressIndicator) Current() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.current
}
func (p *ProgressIndicator) Complete() {
p.mu.Lock()
p.current = p.total
p.mu.Unlock()
p.display()
fmt.Println()
}
func (p *ProgressIndicator) display() {
p.mu.Lock()
current := p.current
p.mu.Unlock()
p.displayWithValue(current)
}
func (p *ProgressIndicator) displayWithValue(current int) {
p.mu.Lock()
total := p.total
description := p.description
showETA := p.showETA
startTime := p.startTime
now := p.clock.Now()
p.mu.Unlock()
percentage := float64(current) / float64(total) * 100
barWidth := 50
filled := int(float64(barWidth) * percentage / 100)
bar := strings.Repeat("=", filled) + strings.Repeat("-", barWidth-filled)
var etaStr string
if showETA && current > 0 {
elapsed := now.Sub(startTime)
rate := float64(current) / elapsed.Seconds()
if rate > 0 {
remaining := float64(total-current) / rate
eta := time.Duration(remaining) * time.Second
etaStr = fmt.Sprintf(" ETA: %s", formatDuration(eta))
}
}
elapsed := now.Sub(startTime)
elapsedStr := formatDuration(elapsed)
fmt.Printf("\r%s [%s] %d/%d (%.1f%%) %s%s",
description, bar, current, total, percentage, elapsedStr, etaStr)
_ = os.Stdout.Sync()
}
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%.0fs", d.Seconds())
} else if d < time.Hour {
return fmt.Sprintf("%.1fm", d.Minutes())
} else {
return fmt.Sprintf("%.1fh", d.Hours())
}
}
type SimpleProgressIndicator struct {
description string
startTime time.Time
current int
clock clock
}
func NewSimpleProgressIndicator(description string) *SimpleProgressIndicator {
now := time.Now()
return &SimpleProgressIndicator{
description: description,
startTime: now,
current: 0,
clock: &realClock{},
}
}
func newSimpleProgressIndicatorWithClock(description string, c clock) *SimpleProgressIndicator {
now := c.Now()
return &SimpleProgressIndicator{
description: description,
startTime: now,
current: 0,
clock: c,
}
}
func (s *SimpleProgressIndicator) Update(current int) {
s.current = current
elapsed := s.clock.Now().Sub(s.startTime)
fmt.Printf("\r%s: %d items processed in %s",
s.description, s.current, formatDuration(elapsed))
_ = os.Stdout.Sync()
}
func (s *SimpleProgressIndicator) Increment() {
s.Update(s.current + 1)
}
func (s *SimpleProgressIndicator) Complete() {
elapsed := s.clock.Now().Sub(s.startTime)
fmt.Printf("\r%s: Completed %d items in %s\n",
s.description, s.current, formatDuration(elapsed))
}
type Spinner struct {
chars []string
index int
message string
startTime time.Time
}
func NewSpinner(message string) *Spinner {
return &Spinner{
chars: []string{"|", "/", "-", "\\"},
index: 0,
message: message,
startTime: time.Now(),
}
}
func (s *Spinner) Spin() {
elapsed := time.Since(s.startTime)
fmt.Printf("\r%s %s (%s)", s.message, s.chars[s.index], formatDuration(elapsed))
s.index = (s.index + 1) % len(s.chars)
_ = os.Stdout.Sync()
}
func (s *Spinner) Complete() {
elapsed := time.Since(s.startTime)
fmt.Printf("\r%s ✓ (%s)\n", s.message, formatDuration(elapsed))
}
type ProgressTracker struct {
description string
startTime time.Time
current int
lastUpdate time.Time
}
func NewProgressTracker(description string) *ProgressTracker {
return &ProgressTracker{
description: description,
startTime: time.Now(),
current: 0,
lastUpdate: time.Now(),
}
}
func (pt *ProgressTracker) Update(current int) {
pt.current = current
now := time.Now()
if now.Sub(pt.lastUpdate) < 200*time.Millisecond {
return
}
pt.lastUpdate = now
elapsed := time.Since(pt.startTime)
rate := float64(current) / elapsed.Seconds()
fmt.Printf("\r%s: %d items processed (%.1f items/sec)",
pt.description, current, rate)
_ = os.Stdout.Sync()
}
func (pt *ProgressTracker) Increment() {
pt.Update(pt.current + 1)
}
func (pt *ProgressTracker) Complete() {
elapsed := time.Since(pt.startTime)
rate := float64(pt.current) / elapsed.Seconds()
fmt.Printf("\r%s: Completed %d items in %s (%.1f items/sec)\n",
pt.description, pt.current, formatDuration(elapsed), rate)
}
type BatchProgressIndicator struct {
totalBatches int
currentBatch int
batchSize int
description string
startTime time.Time
}
func NewBatchProgressIndicator(totalBatches, batchSize int, description string) *BatchProgressIndicator {
return &BatchProgressIndicator{
totalBatches: totalBatches,
currentBatch: 0,
batchSize: batchSize,
description: description,
startTime: time.Now(),
}
}
func (b *BatchProgressIndicator) UpdateBatch(currentBatch int) {
b.currentBatch = currentBatch
elapsed := time.Since(b.startTime)
var etaStr string
if currentBatch > 0 {
rate := float64(currentBatch) / elapsed.Seconds()
if rate > 0 {
remaining := float64(b.totalBatches-currentBatch) / rate
eta := time.Duration(remaining) * time.Second
etaStr = fmt.Sprintf(" ETA: %s", formatDuration(eta))
}
}
fmt.Printf("\r%s: Batch %d/%d (%d items) %s%s",
b.description, currentBatch, b.totalBatches, currentBatch*b.batchSize,
formatDuration(elapsed), etaStr)
_ = os.Stdout.Sync()
}
func (b *BatchProgressIndicator) Complete() {
elapsed := time.Since(b.startTime)
totalItems := b.totalBatches * b.batchSize
fmt.Printf("\r%s: Completed %d batches (%d items) in %s\n",
b.description, b.totalBatches, totalItems, formatDuration(elapsed))
}

View File

@@ -0,0 +1,557 @@
package commands
import (
"bytes"
"io"
"os"
"strings"
"sync"
"testing"
"time"
)
type mockClock struct {
mu sync.RWMutex
now time.Time
}
func newMockClock() *mockClock {
return &mockClock{
now: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
}
func (c *mockClock) Now() time.Time {
c.mu.RLock()
defer c.mu.RUnlock()
return c.now
}
func (c *mockClock) Advance(d time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.now = c.now.Add(d)
}
func (c *mockClock) Set(t time.Time) {
c.mu.Lock()
defer c.mu.Unlock()
c.now = t
}
func captureOutput(fn func()) string {
old := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
defer func() {
_ = w.Close()
os.Stdout = old
}()
fn()
var buf bytes.Buffer
_, _ = io.Copy(&buf, r)
return buf.String()
}
func TestNewProgressIndicator(t *testing.T) {
tests := []struct {
name string
total int
description string
expected *ProgressIndicator
}{
{
name: "basic progress indicator",
total: 100,
description: "Test operation",
expected: &ProgressIndicator{
total: 100,
current: 0,
description: "Test operation",
showETA: true,
},
},
{
name: "zero total",
total: 0,
description: "Empty operation",
expected: &ProgressIndicator{
total: 0,
current: 0,
description: "Empty operation",
showETA: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pi := NewProgressIndicator(tt.total, tt.description)
if pi.total != tt.expected.total {
t.Errorf("expected total %d, got %d", tt.expected.total, pi.total)
}
if pi.current != tt.expected.current {
t.Errorf("expected current %d, got %d", tt.expected.current, pi.current)
}
if pi.description != tt.expected.description {
t.Errorf("expected description %q, got %q", tt.expected.description, pi.description)
}
if pi.showETA != tt.expected.showETA {
t.Errorf("expected showETA %v, got %v", tt.expected.showETA, pi.showETA)
}
if pi.startTime.IsZero() {
t.Error("expected startTime to be set")
}
if pi.lastUpdate.IsZero() {
t.Error("expected lastUpdate to be set")
}
})
}
}
func TestProgressIndicator_Update(t *testing.T) {
clock := newMockClock()
pi := newProgressIndicatorWithClock(10, "Test", clock)
pi.Update(5)
if pi.current != 5 {
t.Errorf("expected current to be 5, got %d", pi.current)
}
originalLastUpdate := pi.lastUpdate
clock.Advance(50 * time.Millisecond)
pi.Update(6)
if pi.current != 6 {
t.Errorf("expected current to be 6, got %d", pi.current)
}
if !pi.lastUpdate.Equal(originalLastUpdate) {
t.Error("expected lastUpdate to remain unchanged due to throttling")
}
clock.Advance(150 * time.Millisecond)
lastUpdateBefore := pi.lastUpdate
pi.Update(7)
if pi.current != 7 {
t.Errorf("expected current to be 7, got %d", pi.current)
}
if pi.lastUpdate.Equal(lastUpdateBefore) {
t.Error("expected lastUpdate to be updated after throttling period")
}
}
func TestProgressIndicator_Increment(t *testing.T) {
pi := NewProgressIndicator(10, "Test")
originalCurrent := pi.current
pi.Increment()
if pi.current != originalCurrent+1 {
t.Errorf("expected current to be %d, got %d", originalCurrent+1, pi.current)
}
}
func TestProgressIndicator_SetDescription(t *testing.T) {
pi := NewProgressIndicator(10, "Original")
newDesc := "New description"
pi.SetDescription(newDesc)
if pi.description != newDesc {
t.Errorf("expected description %q, got %q", newDesc, pi.description)
}
}
func TestProgressIndicator_Complete(t *testing.T) {
pi := NewProgressIndicator(10, "Test")
pi.current = 5
output := captureOutput(func() {
pi.Complete()
})
if pi.current != pi.total {
t.Errorf("expected current to be %d, got %d", pi.total, pi.current)
}
if !strings.Contains(output, "Test") {
t.Error("expected output to contain description")
}
if !strings.Contains(output, "10/10") {
t.Error("expected output to contain final count")
}
if !strings.Contains(output, "100.0%") {
t.Error("expected output to contain 100%")
}
}
func TestProgressIndicator_display(t *testing.T) {
pi := NewProgressIndicator(10, "Test")
pi.current = 3
output := captureOutput(func() {
pi.display()
})
if !strings.Contains(output, "Test") {
t.Error("expected output to contain description")
}
if !strings.Contains(output, "3/10") {
t.Error("expected output to contain current/total")
}
if !strings.Contains(output, "30.0%") {
t.Error("expected output to contain percentage")
}
if !strings.Contains(output, "[") && !strings.Contains(output, "]") {
t.Error("expected output to contain progress bar")
}
}
func TestNewSimpleProgressIndicator(t *testing.T) {
clock := newMockClock()
spi := newSimpleProgressIndicatorWithClock("Test operation", clock)
if spi.description != "Test operation" {
t.Errorf("expected description %q, got %q", "Test operation", spi.description)
}
if spi.current != 0 {
t.Errorf("expected current 0, got %d", spi.current)
}
if spi.startTime.IsZero() {
t.Error("expected startTime to be set")
}
}
func TestSimpleProgressIndicator_Update(t *testing.T) {
clock := newMockClock()
spi := newSimpleProgressIndicatorWithClock("Test", clock)
clock.Advance(2 * time.Second)
output := captureOutput(func() {
spi.Update(5)
})
if spi.current != 5 {
t.Errorf("expected current 5, got %d", spi.current)
}
if !strings.Contains(output, "Test") {
t.Error("expected output to contain description")
}
if !strings.Contains(output, "5 items processed") {
t.Error("expected output to contain item count")
}
if !strings.Contains(output, "2s") {
t.Error("expected output to contain elapsed time (2s)")
}
}
func TestSimpleProgressIndicator_Increment(t *testing.T) {
clock := newMockClock()
spi := newSimpleProgressIndicatorWithClock("Test", clock)
originalCurrent := spi.current
spi.Increment()
if spi.current != originalCurrent+1 {
t.Errorf("expected current to be %d, got %d", originalCurrent+1, spi.current)
}
}
func TestSimpleProgressIndicator_Complete(t *testing.T) {
clock := newMockClock()
spi := newSimpleProgressIndicatorWithClock("Test", clock)
spi.current = 5
clock.Advance(5 * time.Second)
output := captureOutput(func() {
spi.Complete()
})
if !strings.Contains(output, "Test") {
t.Error("expected output to contain description")
}
if !strings.Contains(output, "Completed 5 items") {
t.Error("expected output to contain completion message")
}
if !strings.Contains(output, "5s") {
t.Error("expected output to contain elapsed time (5s)")
}
}
func TestNewSpinner(t *testing.T) {
spinner := NewSpinner("Loading")
if spinner.message != "Loading" {
t.Errorf("expected message %q, got %q", "Loading", spinner.message)
}
if spinner.index != 0 {
t.Errorf("expected index 0, got %d", spinner.index)
}
if len(spinner.chars) != 4 {
t.Errorf("expected 4 chars, got %d", len(spinner.chars))
}
if spinner.startTime.IsZero() {
t.Error("expected startTime to be set")
}
}
func TestSpinner_Spin(t *testing.T) {
spinner := NewSpinner("Loading")
originalIndex := spinner.index
output := captureOutput(func() {
spinner.Spin()
})
if spinner.index != (originalIndex+1)%len(spinner.chars) {
t.Errorf("expected index to increment, got %d", spinner.index)
}
if !strings.Contains(output, "Loading") {
t.Error("expected output to contain message")
}
if !strings.Contains(output, spinner.chars[originalIndex]) {
t.Error("expected output to contain current char")
}
}
func TestSpinner_Complete(t *testing.T) {
spinner := NewSpinner("Loading")
output := captureOutput(func() {
spinner.Complete()
})
if !strings.Contains(output, "Loading") {
t.Error("expected output to contain message")
}
if !strings.Contains(output, "✓") {
t.Error("expected output to contain checkmark")
}
}
func TestNewProgressTracker(t *testing.T) {
pt := NewProgressTracker("Processing")
if pt.description != "Processing" {
t.Errorf("expected description %q, got %q", "Processing", pt.description)
}
if pt.current != 0 {
t.Errorf("expected current 0, got %d", pt.current)
}
if pt.startTime.IsZero() {
t.Error("expected startTime to be set")
}
if pt.lastUpdate.IsZero() {
t.Error("expected lastUpdate to be set")
}
}
func TestProgressTracker_Update(t *testing.T) {
pt := NewProgressTracker("Processing")
pt.Update(5)
if pt.current != 5 {
t.Errorf("expected current to be 5, got %d", pt.current)
}
originalLastUpdate := pt.lastUpdate
pt.Update(6)
if pt.current != 6 {
t.Errorf("expected current to be 6, got %d", pt.current)
}
if !pt.lastUpdate.Equal(originalLastUpdate) {
t.Error("expected lastUpdate to remain unchanged due to throttling")
}
time.Sleep(250 * time.Millisecond)
lastUpdateBefore := pt.lastUpdate
pt.Update(10)
if pt.current != 10 {
t.Errorf("expected current to be 10, got %d", pt.current)
}
if pt.lastUpdate.Equal(lastUpdateBefore) {
t.Error("expected lastUpdate to be updated after throttling period")
}
}
func TestProgressTracker_Increment(t *testing.T) {
pt := NewProgressTracker("Processing")
originalCurrent := pt.current
pt.Increment()
if pt.current != originalCurrent+1 {
t.Errorf("expected current to be %d, got %d", originalCurrent+1, pt.current)
}
}
func TestProgressTracker_Complete(t *testing.T) {
pt := NewProgressTracker("Processing")
pt.current = 10
output := captureOutput(func() {
pt.Complete()
})
if !strings.Contains(output, "Processing") {
t.Error("expected output to contain description")
}
if !strings.Contains(output, "Completed 10 items") {
t.Error("expected output to contain completion message")
}
if !strings.Contains(output, "items/sec") {
t.Error("expected output to contain rate information")
}
}
func TestNewBatchProgressIndicator(t *testing.T) {
bpi := NewBatchProgressIndicator(5, 10, "Batch processing")
if bpi.totalBatches != 5 {
t.Errorf("expected totalBatches 5, got %d", bpi.totalBatches)
}
if bpi.currentBatch != 0 {
t.Errorf("expected currentBatch 0, got %d", bpi.currentBatch)
}
if bpi.batchSize != 10 {
t.Errorf("expected batchSize 10, got %d", bpi.batchSize)
}
if bpi.description != "Batch processing" {
t.Errorf("expected description %q, got %q", "Batch processing", bpi.description)
}
if bpi.startTime.IsZero() {
t.Error("expected startTime to be set")
}
}
func TestBatchProgressIndicator_UpdateBatch(t *testing.T) {
bpi := NewBatchProgressIndicator(5, 10, "Batch processing")
output := captureOutput(func() {
bpi.UpdateBatch(2)
})
if bpi.currentBatch != 2 {
t.Errorf("expected currentBatch 2, got %d", bpi.currentBatch)
}
if !strings.Contains(output, "Batch processing") {
t.Error("expected output to contain description")
}
if !strings.Contains(output, "Batch 2/5") {
t.Error("expected output to contain batch progress")
}
if !strings.Contains(output, "(20 items)") {
t.Error("expected output to contain item count")
}
}
func TestBatchProgressIndicator_Complete(t *testing.T) {
bpi := NewBatchProgressIndicator(5, 10, "Batch processing")
output := captureOutput(func() {
bpi.Complete()
})
if !strings.Contains(output, "Batch processing") {
t.Error("expected output to contain description")
}
if !strings.Contains(output, "Completed 5 batches") {
t.Error("expected output to contain completion message")
}
if !strings.Contains(output, "(50 items)") {
t.Error("expected output to contain total items")
}
}
func TestFormatDuration(t *testing.T) {
tests := []struct {
name string
duration time.Duration
expected string
}{
{
name: "seconds",
duration: 30 * time.Second,
expected: "30s",
},
{
name: "minutes",
duration: 2*time.Minute + 30*time.Second,
expected: "2.5m",
},
{
name: "hours",
duration: 1*time.Hour + 30*time.Minute,
expected: "1.5h",
},
{
name: "zero duration",
duration: 0,
expected: "0s",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatDuration(tt.duration)
if result != tt.expected {
t.Errorf("expected %q, got %q", tt.expected, result)
}
})
}
}
func TestProgressIndicator_Concurrency(t *testing.T) {
pi := NewProgressIndicator(100, "Concurrent test")
done := make(chan bool)
for i := 0; i < 10; i++ {
go func() {
for j := 0; j < 10; j++ {
pi.Increment()
time.Sleep(1 * time.Millisecond)
}
done <- true
}()
}
for i := 0; i < 10; i++ {
<-done
}
if pi.current != 100 {
t.Errorf("expected current to be exactly 100, got %d", pi.current)
}
}
func TestProgressIndicator_EdgeCases(t *testing.T) {
t.Run("zero total constructor", func(t *testing.T) {
pi := NewProgressIndicator(0, "Zero total")
if pi.total != 0 {
t.Errorf("expected total 0, got %d", pi.total)
}
if pi.current != 0 {
t.Errorf("expected current 0, got %d", pi.current)
}
})
t.Run("negative current", func(t *testing.T) {
pi := NewProgressIndicator(10, "Negative test")
pi.current = -1
if pi.current != -1 {
t.Errorf("expected current -1, got %d", pi.current)
}
})
t.Run("current greater than total", func(t *testing.T) {
pi := NewProgressIndicator(10, "Overflow test")
pi.current = 15
if pi.current != 15 {
t.Errorf("expected current 15, got %d", pi.current)
}
})
}

242
cmd/goyco/commands/prune.go Normal file
View File

@@ -0,0 +1,242 @@
package commands
import (
"errors"
"flag"
"fmt"
"os"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/repositories"
)
func HandlePruneCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printPruneUsage)
if err := parseCommand(fs, args, name); err != nil {
if errors.Is(err, ErrHelpRequested) {
return nil
}
return err
}
return withDatabase(cfg, func(db *gorm.DB) error {
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
return runPruneCommand(cfg, userRepo, postRepo, fs.Args())
})
}
func runPruneCommand(_ *config.Config, userRepo repositories.UserRepository, postRepo repositories.PostRepository, args []string) error {
if len(args) == 0 {
printPruneUsage()
return errors.New("missing prune subcommand")
}
switch args[0] {
case "posts":
return prunePosts(postRepo, args[1:])
case "users":
return pruneUsers(userRepo, postRepo, args[1:])
case "all":
return pruneAll(userRepo, postRepo, args[1:])
case "help", "-h", "--help":
printPruneUsage()
return nil
default:
printPruneUsage()
return fmt.Errorf("unknown prune subcommand: %s", args[0])
}
}
func printPruneUsage() {
fmt.Fprintln(os.Stderr, "Prune subcommands:")
fmt.Fprintln(os.Stderr, " posts hard delete posts of deleted users")
fmt.Fprintln(os.Stderr, " users hard delete all users [--with-posts]")
fmt.Fprintln(os.Stderr, " all hard delete all users and posts")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "WARNING: These operations are irreversible!")
fmt.Fprintln(os.Stderr, "Use --dry-run to preview what would be deleted without actually deleting.")
}
func prunePosts(postRepo repositories.PostRepository, args []string) error {
fs := flag.NewFlagSet("prune posts", flag.ContinueOnError)
dryRun := fs.Bool("dry-run", false, "preview what would be deleted without actually deleting")
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
return err
}
posts, err := postRepo.GetPostsByDeletedUsers()
if err != nil {
return fmt.Errorf("get posts by deleted users: %w", err)
}
if len(posts) == 0 {
fmt.Println("No posts found for deleted users")
return nil
}
fmt.Printf("Found %d posts by deleted users:\n", len(posts))
for _, post := range posts {
authorName := "(deleted)"
if post.Author.ID != 0 {
authorName = post.Author.Username
}
fmt.Printf(" ID=%d Title=%s Author=%s URL=%s\n",
post.ID, post.Title, authorName, post.URL)
}
if *dryRun {
fmt.Println("\nDry run: No posts were actually deleted")
return nil
}
fmt.Printf("\nAre you sure you want to permanently delete %d posts? (yes/no): ", len(posts))
var confirmation string
if _, err := fmt.Scanln(&confirmation); err != nil {
return fmt.Errorf("read confirmation: %w", err)
}
if confirmation != "yes" {
fmt.Println("Operation cancelled")
return nil
}
deletedCount, err := postRepo.HardDeletePostsByDeletedUsers()
if err != nil {
return fmt.Errorf("hard delete posts: %w", err)
}
fmt.Printf("Successfully deleted %d posts\n", deletedCount)
return nil
}
func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.PostRepository, args []string) error {
fs := flag.NewFlagSet("prune users", flag.ContinueOnError)
dryRun := fs.Bool("dry-run", false, "preview what would be deleted without actually deleting")
deletePosts := fs.Bool("with-posts", false, "also delete all posts when deleting users")
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
return err
}
users, err := userRepo.GetAll(0, 0)
if err != nil {
return fmt.Errorf("get users: %w", err)
}
userCount := len(users)
if userCount == 0 {
fmt.Println("No users found to delete")
return nil
}
var postCount int64 = 0
if *deletePosts {
postCount, err = postRepo.Count()
if err != nil {
return fmt.Errorf("get post count: %w", err)
}
}
fmt.Printf("Found %d users", userCount)
if *deletePosts {
fmt.Printf(" and %d posts", postCount)
}
fmt.Println(" to delete")
fmt.Println("\nUsers to be deleted:")
for _, user := range users {
fmt.Printf(" ID=%d Username=%s Email=%s\n", user.ID, user.Username, user.Email)
}
if *dryRun {
fmt.Println("\nDry run: No data was actually deleted")
return nil
}
confirmMsg := fmt.Sprintf("\nAre you sure you want to permanently delete %d users", userCount)
if *deletePosts {
confirmMsg += fmt.Sprintf(" and %d posts", postCount)
}
confirmMsg += "? (yes/no): "
fmt.Print(confirmMsg)
var confirmation string
if _, err := fmt.Scanln(&confirmation); err != nil {
return fmt.Errorf("read confirmation: %w", err)
}
if confirmation != "yes" {
fmt.Println("Operation cancelled")
return nil
}
if *deletePosts {
totalDeleted, err := userRepo.HardDeleteAll()
if err != nil {
return fmt.Errorf("hard delete all users and posts: %w", err)
}
fmt.Printf("Successfully deleted %d total records (users, posts, votes, etc.)\n", totalDeleted)
} else {
deletedCount := 0
for _, user := range users {
if err := userRepo.SoftDeleteWithPosts(user.ID); err != nil {
return fmt.Errorf("soft delete user %d: %w", user.ID, err)
}
deletedCount++
}
fmt.Printf("Successfully soft deleted %d users (posts preserved)\n", deletedCount)
}
return nil
}
func pruneAll(userRepo repositories.UserRepository, postRepo repositories.PostRepository, args []string) error {
fs := flag.NewFlagSet("prune all", flag.ContinueOnError)
dryRun := fs.Bool("dry-run", false, "preview what would be deleted without actually deleting")
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
return err
}
userCount, err := userRepo.GetAll(0, 0)
if err != nil {
return fmt.Errorf("get user count: %w", err)
}
postCount, err := postRepo.Count()
if err != nil {
return fmt.Errorf("get post count: %w", err)
}
fmt.Printf("Found %d users and %d posts to delete\n", len(userCount), postCount)
if *dryRun {
fmt.Println("\nDry run: No data was actually deleted")
return nil
}
fmt.Printf("\nAre you sure you want to permanently delete ALL %d users and %d posts? (yes/no): ", len(userCount), postCount)
var confirmation string
if _, err := fmt.Scanln(&confirmation); err != nil {
return fmt.Errorf("read confirmation: %w", err)
}
if confirmation != "yes" {
fmt.Println("Operation cancelled")
return nil
}
totalDeleted, err := userRepo.HardDeleteAll()
if err != nil {
return fmt.Errorf("hard delete all: %w", err)
}
fmt.Printf("Successfully deleted %d total records (users, posts, votes, etc.)\n", totalDeleted)
return nil
}

View File

@@ -0,0 +1,419 @@
package commands
import (
"fmt"
"os"
"strings"
"testing"
"goyco/internal/database"
"goyco/internal/testutils"
)
func TestHandlePruneCommand(t *testing.T) {
tests := []struct {
name string
args []string
wantErr bool
}{
{
name: "help requested",
args: []string{"help"},
wantErr: false,
},
{
name: "missing subcommand",
args: []string{},
wantErr: true,
},
{
name: "unknown subcommand",
args: []string{"unknown"},
wantErr: true,
},
{
name: "posts subcommand",
args: []string{"posts", "--dry-run"},
wantErr: false,
},
{
name: "all subcommand",
args: []string{"all", "--dry-run"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := testutils.NewTestConfig()
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
err := runPruneCommand(cfg, userRepo, postRepo, tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("runPruneCommand() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestRunPruneCommand(t *testing.T) {
tests := []struct {
name string
args []string
wantErr bool
}{
{
name: "help requested",
args: []string{"help"},
wantErr: false,
},
{
name: "missing subcommand",
args: []string{},
wantErr: true,
},
{
name: "unknown subcommand",
args: []string{"unknown"},
wantErr: true,
},
{
name: "posts subcommand",
args: []string{"posts", "--dry-run"},
wantErr: false,
},
{
name: "all subcommand",
args: []string{"all", "--dry-run"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := testutils.NewTestConfig()
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
err := runPruneCommand(cfg, userRepo, postRepo, tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("runPruneCommand() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPrunePosts(t *testing.T) {
postRepo := testutils.NewMockPostRepository()
err := prunePosts(postRepo, []string{"--dry-run"})
if err != nil {
t.Errorf("prunePosts() with dry-run error = %v", err)
}
post1 := database.Post{
ID: 1,
Title: "Post by deleted user 1",
URL: "http://example.com/1",
AuthorID: nil,
}
post2 := database.Post{
ID: 2,
Title: "Post by deleted user 2",
URL: "http://example.com/2",
AuthorID: nil,
}
postRepo.Posts[post1.ID] = &post1
postRepo.Posts[post2.ID] = &post2
postRepo.GetPostsByDeletedUsersFunc = func() ([]database.Post, error) {
return []database.Post{post1, post2}, nil
}
postRepo.HardDeletePostsByDeletedUsersFunc = func() (int64, error) {
delete(postRepo.Posts, post1.ID)
delete(postRepo.Posts, post2.ID)
return 2, nil
}
err = prunePosts(postRepo, []string{"--dry-run"})
if err != nil {
t.Errorf("prunePosts() with dry-run error = %v", err)
}
}
func TestPruneAll(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
err := pruneAll(userRepo, postRepo, []string{"--dry-run"})
if err != nil {
t.Errorf("pruneAll() with dry-run error = %v", err)
}
user1 := database.User{ID: 1, Username: "user1", Email: "user1@example.com"}
user2 := database.User{ID: 2, Username: "user2", Email: "user2@example.com"}
post1 := database.Post{ID: 1, Title: "Post 1", URL: "http://example.com/1", AuthorID: &user1.ID}
post2 := database.Post{ID: 2, Title: "Post 2", URL: "http://example.com/2", AuthorID: &user2.ID}
userRepo.Users[user1.ID] = &user1
userRepo.Users[user2.ID] = &user2
postRepo.Posts[post1.ID] = &post1
postRepo.Posts[post2.ID] = &post2
userRepo.HardDeleteAllFunc = func() (int64, error) {
count := int64(len(userRepo.Users) + len(userRepo.DeletedUsers))
userRepo.Users = make(map[uint]*database.User)
userRepo.DeletedUsers = make(map[uint]*database.User)
return count, nil
}
postRepo.HardDeleteAllFunc = func() (int64, error) {
count := int64(len(postRepo.Posts))
postRepo.Posts = make(map[uint]*database.Post)
return count, nil
}
err = pruneAll(userRepo, postRepo, []string{"--dry-run"})
if err != nil {
t.Errorf("pruneAll() with dry-run error = %v", err)
}
}
func TestPrunePostsWithError(t *testing.T) {
postRepo := testutils.NewMockPostRepository()
postRepo.GetPostsByDeletedUsersFunc = func() ([]database.Post, error) {
return nil, fmt.Errorf("database error")
}
err := prunePosts(postRepo, []string{"--dry-run"})
if err == nil {
t.Errorf("Expected error from GetPostsByDeletedUsers, got nil")
}
if !strings.Contains(err.Error(), "get posts by deleted users") {
t.Errorf("Expected error message to contain 'get posts by deleted users', got: %v", err)
}
}
func TestPruneAllWithUserError(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
userRepo.GetAllFunc = func(limit, offset int) ([]database.User, error) {
return nil, fmt.Errorf("user get error")
}
err := pruneAll(userRepo, postRepo, []string{"--dry-run"})
if err == nil {
t.Errorf("Expected error from GetAll, got nil")
}
if !strings.Contains(err.Error(), "get user count") {
t.Errorf("Expected error message to contain 'get user count', got: %v", err)
}
}
func TestPruneAllWithPostError(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
postRepo.CountFunc = func() (int64, error) {
return 0, fmt.Errorf("post count error")
}
err := pruneAll(userRepo, postRepo, []string{"--dry-run"})
if err == nil {
t.Errorf("Expected error from Count, got nil")
}
if !strings.Contains(err.Error(), "get post count") {
t.Errorf("Expected error message to contain 'get post count', got: %v", err)
}
}
func TestPrintPruneUsage(t *testing.T) {
printPruneUsage()
}
func TestPruneFlagParsing(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
t.Run("prunePosts unknown flag", func(t *testing.T) {
err := prunePosts(postRepo, []string{"--unknown-flag"})
if err == nil {
t.Error("expected error for unknown flag in prunePosts")
}
})
t.Run("prunePosts missing dry-run value (bool)", func(t *testing.T) {
err := prunePosts(postRepo, []string{"--dry-run"})
if err != nil {
t.Errorf("unexpected error for dry-run: %v", err)
}
})
t.Run("pruneUsers unknown flag", func(t *testing.T) {
err := pruneUsers(userRepo, postRepo, []string{"--unknown-flag"})
if err == nil {
t.Error("expected error for unknown flag in pruneUsers")
}
})
t.Run("pruneUsers with-posts as non-bool", func(t *testing.T) {
err := pruneUsers(userRepo, postRepo, []string{"--with-posts", "true"})
if err != nil {
t.Errorf("unexpected error for with-posts: %v", err)
}
})
t.Run("pruneAll unknown flag", func(t *testing.T) {
err := pruneAll(userRepo, postRepo, []string{"--unknown-flag"})
if err == nil {
t.Error("expected error for unknown flag in pruneAll")
}
})
}
func TestPrunePostsWithMockData(t *testing.T) {
postRepo := testutils.NewMockPostRepository()
post1 := database.Post{
ID: 1,
Title: "Test Post 1",
URL: "http://example.com/1",
AuthorID: nil,
}
post2 := database.Post{
ID: 2,
Title: "Test Post 2",
URL: "http://example.com/2",
AuthorID: nil,
}
postRepo.GetPostsByDeletedUsersFunc = func() ([]database.Post, error) {
return []database.Post{post1, post2}, nil
}
err := prunePosts(postRepo, []string{"--dry-run"})
if err != nil {
t.Errorf("prunePosts() with mock data error = %v", err)
}
}
func TestPruneAllWithMockData(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
userRepo.HardDeleteAllFunc = func() (int64, error) {
return 5, nil
}
postRepo.HardDeleteAllFunc = func() (int64, error) {
return 10, nil
}
err := pruneAll(userRepo, postRepo, []string{"--dry-run"})
if err != nil {
t.Errorf("pruneAll() with mock data error = %v", err)
}
}
func TestPrunePostsActualDeletion(t *testing.T) {
postRepo := testutils.NewMockPostRepository()
post1 := database.Post{
ID: 1,
Title: "Test Post 1",
URL: "http://example.com/1",
AuthorID: nil,
}
post2 := database.Post{
ID: 2,
Title: "Test Post 2",
URL: "http://example.com/2",
AuthorID: nil,
}
postRepo.GetPostsByDeletedUsersFunc = func() ([]database.Post, error) {
return []database.Post{post1, post2}, nil
}
var deletedCount int64
postRepo.HardDeletePostsByDeletedUsersFunc = func() (int64, error) {
deletedCount = 2
return 2, nil
}
originalStdin := os.Stdin
defer func() { os.Stdin = originalStdin }()
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("Failed to create pipe: %v", err)
}
defer func() { _ = r.Close() }()
defer func() { _ = w.Close() }()
os.Stdin = r
go func() {
_, _ = w.WriteString("yes\n")
_ = w.Close()
}()
err = prunePosts(postRepo, []string{})
if err != nil {
t.Errorf("prunePosts() actual deletion error = %v", err)
}
if deletedCount != 2 {
t.Errorf("Expected 2 posts to be deleted, got %d", deletedCount)
}
}
func TestPruneAllActualDeletion(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
user1 := database.User{ID: 1, Username: "user1", Email: "user1@example.com"}
user2 := database.User{ID: 2, Username: "user2", Email: "user2@example.com"}
post1 := database.Post{ID: 1, Title: "Post 1", URL: "http://example.com/1", AuthorID: &user1.ID}
post2 := database.Post{ID: 2, Title: "Post 2", URL: "http://example.com/2", AuthorID: &user2.ID}
userRepo.Users[user1.ID] = &user1
userRepo.Users[user2.ID] = &user2
postRepo.Posts[post1.ID] = &post1
postRepo.Posts[post2.ID] = &post2
var totalDeleted int64
userRepo.HardDeleteAllFunc = func() (int64, error) {
totalDeleted = 2
return 2, nil
}
originalStdin := os.Stdin
defer func() { os.Stdin = originalStdin }()
reader, writer, err := os.Pipe()
if err != nil {
t.Fatalf("Failed to create pipe: %v", err)
}
defer func() { _ = reader.Close() }()
defer func() { _ = writer.Close() }()
os.Stdin = reader
go func() {
_, _ = writer.WriteString("yes\n")
_ = writer.Close()
}()
err = pruneAll(userRepo, postRepo, []string{})
if err != nil {
t.Errorf("pruneAll() actual deletion error = %v", err)
}
if totalDeleted != 2 {
t.Errorf("Expected 2 users to be deleted, got %d", totalDeleted)
}
}

353
cmd/goyco/commands/seed.go Normal file
View File

@@ -0,0 +1,353 @@
package commands
import (
"crypto/rand"
"errors"
"flag"
"fmt"
"math/big"
"os"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/repositories"
)
func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printSeedUsage)
if err := parseCommand(fs, args, name); err != nil {
if errors.Is(err, ErrHelpRequested) {
return nil
}
return err
}
return withDatabase(cfg, func(db *gorm.DB) error {
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
return runSeedCommand(userRepo, postRepo, voteRepo, fs.Args())
})
}
func runSeedCommand(userRepo repositories.UserRepository, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, args []string) error {
if len(args) == 0 {
printSeedUsage()
return errors.New("missing seed subcommand")
}
switch args[0] {
case "database":
return seedDatabase(userRepo, postRepo, voteRepo, args[1:])
case "help", "-h", "--help":
printSeedUsage()
return nil
default:
printSeedUsage()
return fmt.Errorf("unknown seed subcommand: %s", args[0])
}
}
func printSeedUsage() {
fmt.Fprintln(os.Stderr, "Seed subcommands:")
fmt.Fprintln(os.Stderr, " database [--posts <n>] [--users <n>] [--votes-per-post <n>]")
fmt.Fprintln(os.Stderr, " --posts: number of posts to create (default: 40)")
fmt.Fprintln(os.Stderr, " --users: number of additional users to create (default: 5)")
fmt.Fprintln(os.Stderr, " --votes-per-post: average votes per post (default: 15)")
}
func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, args []string) error {
fs := flag.NewFlagSet("seed database", flag.ContinueOnError)
numPosts := fs.Int("posts", 40, "number of posts to create")
numUsers := fs.Int("users", 5, "number of additional users to create")
votesPerPost := fs.Int("votes-per-post", 15, "average votes per post")
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
return err
}
fmt.Println("Starting database seeding...")
spinner := NewSpinner("Creating seed user")
spinner.Spin()
seedUser, err := ensureSeedUser(userRepo)
if err != nil {
spinner.Complete()
return fmt.Errorf("ensure seed user: %w", err)
}
spinner.Complete()
fmt.Printf("Seed user ready: ID=%d Username=%s\n", seedUser.ID, seedUser.Username)
processor := NewParallelProcessor()
progress := NewProgressIndicator(*numUsers, "Creating users (parallel)")
users, err := processor.CreateUsersInParallel(userRepo, *numUsers, progress)
if err != nil {
return fmt.Errorf("create random users: %w", err)
}
progress.Complete()
allUsers := append([]database.User{*seedUser}, users...)
progress = NewProgressIndicator(*numPosts, "Creating posts (parallel)")
posts, err := processor.CreatePostsInParallel(postRepo, seedUser.ID, *numPosts, progress)
if err != nil {
return fmt.Errorf("create random posts: %w", err)
}
progress.Complete()
progress = NewProgressIndicator(len(posts), "Creating votes (parallel)")
votes, err := processor.CreateVotesInParallel(voteRepo, allUsers, posts, *votesPerPost, progress)
if err != nil {
return fmt.Errorf("create random votes: %w", err)
}
progress.Complete()
progress = NewProgressIndicator(len(posts), "Updating scores (parallel)")
err = processor.UpdatePostScoresInParallel(postRepo, voteRepo, posts, progress)
if err != nil {
return fmt.Errorf("update post scores: %w", err)
}
progress.Complete()
fmt.Println("Database seeding completed successfully!")
fmt.Printf("Created %d users, %d posts, and %d votes\n", len(allUsers), len(posts), votes)
return nil
}
func ensureSeedUser(userRepo repositories.UserRepository) (*database.User, error) {
seedUsername := "seed_admin"
seedEmail := "seed_admin@goyco.local"
seedPassword := "seed-password"
user, err := userRepo.GetByEmail(seedEmail)
if err == nil {
return user, nil
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(seedPassword), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
user = &database.User{
Username: seedUsername,
Email: seedEmail,
Password: string(hashedPassword),
EmailVerified: true,
}
if err := userRepo.Create(user); err != nil {
return nil, fmt.Errorf("create seed user: %w", err)
}
return user, nil
}
func createRandomUsers(userRepo repositories.UserRepository, count int) ([]database.User, error) {
var users []database.User
for i := range count {
username := fmt.Sprintf("user_%d", i+1)
email := fmt.Sprintf("user_%d@goyco.local", i+1)
password := "password123"
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("hash password for user %d: %w", i+1, err)
}
user := &database.User{
Username: username,
Email: email,
Password: string(hashedPassword),
EmailVerified: true,
}
if err := userRepo.Create(user); err != nil {
return nil, fmt.Errorf("create user %d: %w", i+1, err)
}
users = append(users, *user)
}
return users, nil
}
func createRandomPosts(postRepo repositories.PostRepository, authorID uint, count int) ([]database.Post, error) {
var posts []database.Post
sampleTitles := []string{
"Amazing JavaScript Framework",
"Python Best Practices",
"Go Performance Tips",
"Database Optimization",
"Web Security Guide",
"Machine Learning Basics",
"Cloud Architecture",
"DevOps Automation",
"API Design Patterns",
"Frontend Optimization",
"Backend Scaling",
"Container Orchestration",
"Microservices Architecture",
"Testing Strategies",
"Code Review Process",
"Version Control Best Practices",
"Continuous Integration",
"Monitoring and Alerting",
"Error Handling Patterns",
"Data Structures Explained",
}
sampleDomains := []string{
"example.com",
"techblog.org",
"devguide.net",
"programming.io",
"codeexamples.com",
"tutorialhub.org",
"bestpractices.dev",
"learnprogramming.net",
"codingtips.org",
"softwareengineering.com",
}
for i := range count {
title := sampleTitles[i%len(sampleTitles)]
if i >= len(sampleTitles) {
title = fmt.Sprintf("%s - Part %d", title, (i/len(sampleTitles))+1)
}
domain := sampleDomains[i%len(sampleDomains)]
path := generateRandomPath()
url := fmt.Sprintf("https://%s%s", domain, path)
content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", i+1, title)
post := &database.Post{
Title: title,
URL: url,
Content: content,
AuthorID: &authorID,
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
if err := postRepo.Create(post); err != nil {
return nil, fmt.Errorf("create post %d: %w", i+1, err)
}
posts = append(posts, *post)
}
return posts, nil
}
func generateRandomPath() string {
pathLength, _ := rand.Int(rand.Reader, big.NewInt(20))
path := "/article/"
for i := int64(0); i < pathLength.Int64()+5; i++ {
randomChar, _ := rand.Int(rand.Reader, big.NewInt(26))
path += string(rune('a' + randomChar.Int64()))
}
return path
}
func createRandomVotes(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int) (int, error) {
totalVotes := 0
for _, post := range posts {
voteCount, _ := rand.Int(rand.Reader, big.NewInt(int64(avgVotesPerPost*2)+1))
numVotes := int(voteCount.Int64())
if numVotes == 0 && avgVotesPerPost > 0 {
chance, _ := rand.Int(rand.Reader, big.NewInt(5))
if chance.Int64() > 0 {
numVotes = 1
}
}
usedUsers := make(map[uint]bool)
for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
userIdx, _ := rand.Int(rand.Reader, big.NewInt(int64(len(users))))
user := users[userIdx.Int64()]
if usedUsers[user.ID] {
continue
}
usedUsers[user.ID] = true
voteTypeInt, _ := rand.Int(rand.Reader, big.NewInt(10))
var voteType database.VoteType
if voteTypeInt.Int64() < 7 {
voteType = database.VoteUp
} else {
voteType = database.VoteDown
}
vote := &database.Vote{
UserID: &user.ID,
PostID: post.ID,
Type: voteType,
}
if err := voteRepo.Create(vote); err != nil {
return totalVotes, fmt.Errorf("create vote for post %d: %w", post.ID, err)
}
totalVotes++
}
}
return totalVotes, nil
}
func updatePostScores(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post) error {
for _, post := range posts {
upVotes, downVotes, err := getVoteCounts(voteRepo, post.ID)
if err != nil {
return fmt.Errorf("get vote counts for post %d: %w", post.ID, err)
}
post.UpVotes = upVotes
post.DownVotes = downVotes
post.Score = upVotes - downVotes
if err := postRepo.Update(&post); err != nil {
return fmt.Errorf("update post %d scores: %w", post.ID, err)
}
}
return nil
}
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
votes, err := voteRepo.GetByPostID(postID)
if err != nil {
return 0, 0, err
}
upVotes := 0
downVotes := 0
for _, vote := range votes {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
return upVotes, downVotes, nil
}

View File

@@ -0,0 +1,181 @@
package commands
import (
"testing"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/testutils"
)
func TestSeedCommand(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
err = db.AutoMigrate(&database.User{}, &database.Post{}, &database.Vote{})
if err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
seedUser, err := ensureSeedUser(userRepo)
if err != nil {
t.Fatalf("Failed to ensure seed user: %v", err)
}
if seedUser.Username != "seed_admin" {
t.Errorf("Expected username 'seed_admin', got '%s'", seedUser.Username)
}
if seedUser.Email != "seed_admin@goyco.local" {
t.Errorf("Expected email 'seed_admin@goyco.local', got '%s'", seedUser.Email)
}
if !seedUser.EmailVerified {
t.Error("Expected seed user to be email verified")
}
users, err := createRandomUsers(userRepo, 2)
if err != nil {
t.Fatalf("Failed to create random users: %v", err)
}
if len(users) != 2 {
t.Errorf("Expected 2 users, got %d", len(users))
}
posts, err := createRandomPosts(postRepo, seedUser.ID, 5)
if err != nil {
t.Fatalf("Failed to create random posts: %v", err)
}
if len(posts) != 5 {
t.Errorf("Expected 5 posts, got %d", len(posts))
}
for i, post := range posts {
if post.Title == "" {
t.Errorf("Post %d has empty title", i)
}
if post.URL == "" {
t.Errorf("Post %d has empty URL", i)
}
if post.AuthorID == nil || *post.AuthorID != seedUser.ID {
t.Errorf("Post %d has wrong author ID: expected %d, got %v", i, seedUser.ID, post.AuthorID)
}
}
allUsers := append([]database.User{*seedUser}, users...)
votes, err := createRandomVotes(voteRepo, allUsers, posts, 3)
if err != nil {
t.Fatalf("Failed to create random votes: %v", err)
}
if votes == 0 {
t.Error("Expected some votes to be created")
}
err = updatePostScores(postRepo, voteRepo, posts)
if err != nil {
t.Fatalf("Failed to update post scores: %v", err)
}
for i, post := range posts {
updatedPost, err := postRepo.GetByID(post.ID)
if err != nil {
t.Errorf("Failed to get updated post %d: %v", i, err)
continue
}
expectedScore := updatedPost.UpVotes - updatedPost.DownVotes
if updatedPost.Score != expectedScore {
t.Errorf("Post %d has incorrect score: expected %d, got %d", i, expectedScore, updatedPost.Score)
}
}
}
func TestGenerateRandomPath(t *testing.T) {
path := generateRandomPath()
if path == "" {
t.Error("Generated path should not be empty")
}
if len(path) < 8 {
t.Errorf("Generated path too short: %s", path)
}
secondPath := generateRandomPath()
if path == secondPath {
t.Error("Generated paths should be different")
}
}
func TestSeedDatabaseFlagParsing(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
voteRepo := testutils.NewMockVoteRepository()
t.Run("invalid posts type", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--posts", "abc"})
if err == nil {
t.Error("expected error for invalid posts type")
}
})
t.Run("invalid users type", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "xyz"})
if err == nil {
t.Error("expected error for invalid users type")
}
})
t.Run("invalid votes-per-post type", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--votes-per-post", "invalid"})
if err == nil {
t.Error("expected error for invalid votes-per-post type")
}
})
t.Run("unknown flag", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--unknown-flag"})
if err == nil {
t.Error("expected error for unknown flag")
}
})
t.Run("missing posts value", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--posts"})
if err == nil {
t.Error("expected error for missing posts value")
}
})
t.Run("missing users value", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users"})
if err == nil {
t.Error("expected error for missing users value")
}
})
t.Run("missing votes-per-post value", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--votes-per-post"})
if err == nil {
t.Error("expected error for missing votes-per-post value")
}
})
}

907
cmd/goyco/commands/user.go Normal file
View File

@@ -0,0 +1,907 @@
package commands
import (
"crypto/rand"
"errors"
"flag"
"fmt"
"math/big"
"os"
"strconv"
"strings"
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/security"
"goyco/internal/services"
)
func HandleUserCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printUserUsage)
if err := parseCommand(fs, args, name); err != nil {
if errors.Is(err, ErrHelpRequested) {
return nil
}
return err
}
return withDatabase(cfg, func(db *gorm.DB) error {
repo := repositories.NewUserRepository(db)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
return runUserCommand(cfg, repo, refreshTokenRepo, fs.Args())
})
}
func runUserCommand(cfg *config.Config, repo repositories.UserRepository, refreshTokenRepo repositories.RefreshTokenRepositoryInterface, args []string) error {
if len(args) == 0 {
printUserUsage()
return errors.New("missing user subcommand")
}
switch args[0] {
case "create":
return userCreate(cfg, repo, args[1:])
case "update":
return userUpdate(cfg, repo, refreshTokenRepo, args[1:])
case "delete":
return userDelete(cfg, repo, args[1:])
case "lock":
return userLock(cfg, repo, args[1:])
case "unlock":
return userUnlock(cfg, repo, args[1:])
case "list":
return userList(repo, args[1:])
case "help", "-h", "--help":
printUserUsage()
return nil
default:
printUserUsage()
return fmt.Errorf("unknown user subcommand: %s", args[0])
}
}
func printUserUsage() {
fmt.Fprintln(os.Stderr, "User subcommands:")
fmt.Fprintln(os.Stderr, " create --username <name> --email <email> --password <password>")
fmt.Fprintln(os.Stderr, " update <id> [--username <name>] [--email <email>] [--password <password>] [--reset-password]")
fmt.Fprintln(os.Stderr, " delete <id> [--with-posts]")
fmt.Fprintln(os.Stderr, " lock <id>")
fmt.Fprintln(os.Stderr, " unlock <id>")
fmt.Fprintln(os.Stderr, " list [--limit <n>] [--offset <n>]")
}
func createSessionService(cfg *config.Config, userRepo repositories.UserRepository, refreshTokenRepo repositories.RefreshTokenRepositoryInterface) *services.SessionService {
jwtService := services.NewJWTService(&cfg.JWT, userRepo, refreshTokenRepo)
return services.NewSessionService(jwtService, userRepo)
}
func userCreate(cfg *config.Config, repo repositories.UserRepository, args []string) error {
fs := flag.NewFlagSet("user create", flag.ContinueOnError)
username := fs.String("username", "", "username")
email := fs.String("email", "", "email")
password := fs.String("password", "", "password")
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
return err
}
if *username == "" || *email == "" || *password == "" {
fs.Usage()
return errors.New("username, email, and password are required")
}
auditLogger, err := NewAuditLogger(cfg.LogDir)
if err != nil {
fmt.Printf("Warning: Could not initialize audit logging: %v\n", err)
auditLogger = nil
}
sanitizer := security.NewInputSanitizer()
sanitizedUsername, err := sanitizer.SanitizeUsernameCLI(*username)
if err != nil {
if auditLogger != nil {
auditLogger.LogUserCreation(0, *username, *email, false, err)
}
return fmt.Errorf("username validation: %w", err)
}
sanitizedEmail, err := sanitizer.SanitizeEmailCLI(*email)
if err != nil {
if auditLogger != nil {
auditLogger.LogUserCreation(0, sanitizedUsername, *email, false, err)
}
return fmt.Errorf("email validation: %w", err)
}
if err := sanitizer.SanitizePasswordCLI(*password); err != nil {
if auditLogger != nil {
auditLogger.LogUserCreation(0, sanitizedUsername, sanitizedEmail, false, err)
}
return fmt.Errorf("password validation: %w", err)
}
_, err = repo.GetByUsername(sanitizedUsername)
if err == nil {
return fmt.Errorf("username %s already exists", sanitizedUsername)
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("check username: %w", err)
}
_, err = repo.GetByEmail(sanitizedEmail)
if err == nil {
return fmt.Errorf("email %s already exists", sanitizedEmail)
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("check email: %w", err)
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
now := time.Now()
user := &database.User{
Username: sanitizedUsername,
Email: sanitizedEmail,
Password: string(hashedPassword),
EmailVerified: true,
EmailVerifiedAt: &now,
}
if err := repo.Create(user); err != nil {
if auditLogger != nil {
auditLogger.LogUserCreation(0, sanitizedUsername, sanitizedEmail, false, err)
}
return handleDatabaseConstraintError(err)
}
if auditLogger != nil {
auditLogger.LogUserCreation(user.ID, user.Username, user.Email, true, nil)
}
fmt.Printf("User created: %s (%s)\n", user.Username, user.Email)
return nil
}
func userUpdate(cfg *config.Config, repo repositories.UserRepository, refreshTokenRepo repositories.RefreshTokenRepositoryInterface, args []string) error {
if len(args) == 0 {
return errors.New("user ID is required")
}
idStr := args[0]
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
return fmt.Errorf("invalid user ID: %s", idStr)
}
if id == 0 {
return errors.New("user ID must be greater than 0")
}
fs := flag.NewFlagSet("user update", flag.ContinueOnError)
username := fs.String("username", "", "new username")
email := fs.String("email", "", "new email")
password := fs.String("password", "", "new password")
resetPassword := fs.Bool("reset-password", false, "reset password and send temporary password via email")
fs.SetOutput(os.Stderr)
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of user update:\n")
fmt.Fprintf(os.Stderr, " --email string\n")
fmt.Fprintf(os.Stderr, " new email\n")
fmt.Fprintf(os.Stderr, " --password string\n")
fmt.Fprintf(os.Stderr, " new password\n")
fmt.Fprintf(os.Stderr, " --reset-password\n")
fmt.Fprintf(os.Stderr, " reset password and send temporary password via email\n")
fmt.Fprintf(os.Stderr, " --username string\n")
fmt.Fprintf(os.Stderr, " new username\n")
}
if err := fs.Parse(args[1:]); err != nil {
return err
}
if *username == "" && *email == "" && *password == "" && !*resetPassword {
fs.Usage()
return errors.New("no update options provided")
}
sanitizer := security.NewInputSanitizer()
if *username != "" {
sanitizedUsername, err := sanitizer.SanitizeUsernameCLI(*username)
if err != nil {
return fmt.Errorf("username validation: %w", err)
}
*username = sanitizedUsername
}
if *email != "" {
sanitizedEmail, err := sanitizer.SanitizeEmailCLI(*email)
if err != nil {
return fmt.Errorf("email validation: %w", err)
}
*email = sanitizedEmail
}
if *password != "" {
if err := sanitizer.SanitizePasswordCLI(*password); err != nil {
return fmt.Errorf("password validation: %w", err)
}
}
if *resetPassword {
sessionService := createSessionService(cfg, repo, refreshTokenRepo)
return resetUserPassword(cfg, repo, sessionService, uint(id))
}
user, err := repo.GetByID(uint(id))
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("user %d not found", id)
}
return fmt.Errorf("fetch user: %w", err)
}
if *username != "" && *username != user.Username {
if err := checkUsernameAvailable(repo, *username, uint(id)); err != nil {
return err
}
user.Username = *username
}
if *email != "" && *email != user.Email {
if err := checkEmailAvailable(repo, *email, uint(id)); err != nil {
return err
}
user.Email = *email
}
if *password != "" {
if len(*password) < 8 {
return errors.New("password must be at least 8 characters")
}
hashedPassword, hashErr := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
if hashErr != nil {
return fmt.Errorf("hash password: %w", hashErr)
}
user.Password = string(hashedPassword)
sessionService := createSessionService(cfg, repo, refreshTokenRepo)
if err := sessionService.InvalidateAllSessions(user.ID); err != nil {
return fmt.Errorf("invalidate sessions: %w", err)
}
}
if err := repo.Update(user); err != nil {
return handleDatabaseConstraintError(err)
}
fmt.Printf("User updated: %s (%s)\n", user.Username, user.Email)
return nil
}
func checkUsernameAvailable(repo repositories.UserRepository, username string, excludeID uint) error {
existing, err := repo.GetByUsernameIncludingDeleted(username)
if err == nil && existing.ID != excludeID {
return fmt.Errorf("username %s is already taken", username)
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("check username availability: %w", err)
}
return nil
}
func checkEmailAvailable(repo repositories.UserRepository, email string, excludeID uint) error {
existing, err := repo.GetByEmail(email)
if err == nil && existing.ID != excludeID {
return fmt.Errorf("email %s is already registered", email)
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("check email availability: %w", err)
}
return nil
}
func handleDatabaseConstraintError(err error) error {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
if strings.Contains(pqErr.Constraint, "username") {
return fmt.Errorf("username is already taken")
}
if strings.Contains(pqErr.Constraint, "email") {
return fmt.Errorf("email is already registered")
}
return fmt.Errorf("data already exists (constraint violation)")
}
return fmt.Errorf("update user: %w", err)
}
func userDelete(cfg *config.Config, repo repositories.UserRepository, args []string) error {
var userID string
var flagArgs []string
for _, arg := range args {
if strings.HasPrefix(arg, "-") {
flagArgs = append(flagArgs, arg)
} else if userID == "" {
userID = arg
} else {
flagArgs = append(flagArgs, arg)
}
}
fs := flag.NewFlagSet("user delete", flag.ContinueOnError)
deletePosts := fs.Bool("with-posts", false, "also delete user's posts (default: keep posts)")
fs.SetOutput(os.Stderr)
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of user delete:\n")
fmt.Fprintf(os.Stderr, " --with-posts\n")
fmt.Fprintf(os.Stderr, " also delete user's posts (default: keep posts)\n")
}
if err := fs.Parse(flagArgs); err != nil {
return err
}
if userID == "" {
fs.Usage()
return errors.New("user ID is required")
}
idStr := userID
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
return fmt.Errorf("invalid user ID: %s", idStr)
}
if id == 0 {
return errors.New("user ID must be greater than 0")
}
user, err := repo.GetByID(uint(id))
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
_, deletedErr := repo.GetByIDIncludingDeleted(uint(id))
if deletedErr == nil {
return fmt.Errorf("user with ID %d is already deleted", id)
}
return fmt.Errorf("user with ID %d not found", id)
}
return fmt.Errorf("get user: %w", err)
}
var deleteErr error
if *deletePosts {
deleteErr = repo.HardDelete(uint(id))
if deleteErr == nil {
fmt.Printf("User deleted: ID=%d (posts also deleted)\n", id)
}
} else {
deleteErr = repo.SoftDeleteWithPosts(uint(id))
if deleteErr == nil {
fmt.Printf("User deleted: ID=%d (posts kept)\n", id)
}
}
if deleteErr != nil {
return fmt.Errorf("delete user: %w", deleteErr)
}
emailSender := services.NewSMTPSenderWithTimeout(cfg.SMTP.Host, cfg.SMTP.Port, cfg.SMTP.Username, cfg.SMTP.Password, cfg.SMTP.From, cfg.SMTP.Timeout)
subject, body := services.GenerateAdminAccountDeletionNotificationEmail(user.Username, cfg.App.AdminEmail, cfg.App.BaseURL, cfg.App.Title, *deletePosts)
if err := emailSender.Send(user.Email, subject, body); err != nil {
fmt.Printf("Warning: Could not send notification email to %s: %v\n", user.Email, err)
} else {
fmt.Printf("Notification email sent to %s\n", user.Email)
}
return nil
}
func userList(repo repositories.UserRepository, args []string) error {
fs := flag.NewFlagSet("user list", flag.ContinueOnError)
limit := fs.Int("limit", 0, "max number of users to list")
offset := fs.Int("offset", 0, "number of users to skip")
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
return err
}
users, err := repo.GetAll(*limit, *offset)
if err != nil {
return fmt.Errorf("list users: %w", err)
}
if len(users) == 0 {
fmt.Println("No users found")
return nil
}
maxIDWidth := 2
maxUsernameWidth := 8
maxEmailWidth := 5
maxLockedWidth := 6
maxCreatedAtWidth := 10
for _, u := range users {
lockedStatus := "No"
if u.Locked {
lockedStatus = "Yes"
}
createdAtStr := u.CreatedAt.Format("2006-01-02 15:04:05")
if len(fmt.Sprintf("%d", u.ID)) > maxIDWidth {
maxIDWidth = len(fmt.Sprintf("%d", u.ID))
}
if len(u.Username) > maxUsernameWidth {
maxUsernameWidth = len(u.Username)
}
if len(u.Email) > maxEmailWidth {
maxEmailWidth = len(u.Email)
}
if len(lockedStatus) > maxLockedWidth {
maxLockedWidth = len(lockedStatus)
}
if len(createdAtStr) > maxCreatedAtWidth {
maxCreatedAtWidth = len(createdAtStr)
}
}
fmt.Printf("%-*s %-*s %-*s %-*s %s\n",
maxIDWidth, "ID",
maxUsernameWidth, "Username",
maxEmailWidth, "Email",
maxLockedWidth, "Locked",
"CreatedAt")
for _, u := range users {
lockedStatus := "No"
if u.Locked {
lockedStatus = "Yes"
}
createdAtStr := u.CreatedAt.Format("2006-01-02 15:04:05")
fmt.Printf("%-*d %-*s %-*s %-*s %s\n",
maxIDWidth, u.ID,
maxUsernameWidth, u.Username,
maxEmailWidth, u.Email,
maxLockedWidth, lockedStatus,
createdAtStr)
}
return nil
}
func userLock(cfg *config.Config, repo repositories.UserRepository, args []string) error {
fs := flag.NewFlagSet("user lock", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
return err
}
if fs.NArg() == 0 {
fs.Usage()
return errors.New("user ID is required")
}
idStr := fs.Arg(0)
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
return fmt.Errorf("invalid user ID: %s", idStr)
}
if id == 0 {
return errors.New("user ID must be greater than 0")
}
user, err := repo.GetByID(uint(id))
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("user with ID %d not found", id)
}
return fmt.Errorf("get user: %w", err)
}
if user.Locked {
fmt.Printf("User is already locked: %s\n", user.Username)
return nil
}
if err := repo.Lock(uint(id)); err != nil {
return fmt.Errorf("lock user: %w", err)
}
fmt.Printf("User locked: %s\n", user.Username)
emailSender := services.NewSMTPSenderWithTimeout(cfg.SMTP.Host, cfg.SMTP.Port, cfg.SMTP.Username, cfg.SMTP.Password, cfg.SMTP.From, cfg.SMTP.Timeout)
subject, body := services.GenerateAccountLockNotificationEmail(user.Username, cfg.App.AdminEmail, cfg.App.Title)
if err := emailSender.Send(user.Email, subject, body); err != nil {
fmt.Printf("Warning: Could not send notification email to %s: %v\n", user.Email, err)
} else {
fmt.Printf("Notification email sent to %s\n", user.Email)
}
return nil
}
func userUnlock(cfg *config.Config, repo repositories.UserRepository, args []string) error {
fs := flag.NewFlagSet("user unlock", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
if err := fs.Parse(args); err != nil {
return err
}
if fs.NArg() == 0 {
fs.Usage()
return errors.New("user ID is required")
}
idStr := fs.Arg(0)
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
return fmt.Errorf("invalid user ID: %s", idStr)
}
if id == 0 {
return errors.New("user ID must be greater than 0")
}
user, err := repo.GetByID(uint(id))
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("user with ID %d not found", id)
}
return fmt.Errorf("get user: %w", err)
}
if !user.Locked {
fmt.Printf("User is already unlocked: %s\n", user.Username)
return nil
}
if err := repo.Unlock(uint(id)); err != nil {
return fmt.Errorf("unlock user: %w", err)
}
fmt.Printf("User unlocked: %s\n", user.Username)
emailSender := services.NewSMTPSenderWithTimeout(cfg.SMTP.Host, cfg.SMTP.Port, cfg.SMTP.Username, cfg.SMTP.Password, cfg.SMTP.From, cfg.SMTP.Timeout)
subject, body := services.GenerateAccountUnlockNotificationEmail(user.Username, cfg.App.AdminEmail, cfg.App.BaseURL, cfg.App.Title)
if err := emailSender.Send(user.Email, subject, body); err != nil {
fmt.Printf("Warning: Could not send notification email to %s: %v\n", user.Email, err)
} else {
fmt.Printf("Notification email sent to %s\n", user.Email)
}
return nil
}
func resetUserPassword(cfg *config.Config, repo repositories.UserRepository, sessionService *services.SessionService, userID uint) error {
user, err := repo.GetByID(userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("user %d not found", userID)
}
return fmt.Errorf("fetch user: %w", err)
}
tempPassword, err := generateTemporaryPassword()
if err != nil {
return fmt.Errorf("generate temporary password: %w", err)
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(tempPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
user.Password = string(hashedPassword)
if err := repo.Update(user); err != nil {
return fmt.Errorf("update password: %w", err)
}
if err := sessionService.InvalidateAllSessions(userID); err != nil {
return fmt.Errorf("invalidate sessions: %w", err)
}
emailSender := services.NewSMTPSenderWithTimeout(cfg.SMTP.Host, cfg.SMTP.Port, cfg.SMTP.Username, cfg.SMTP.Password, cfg.SMTP.From, cfg.SMTP.Timeout)
subject := fmt.Sprintf("Password Reset - %s", cfg.App.Title)
body := generatePasswordResetEmailBody(user.Username, tempPassword, cfg.App.BaseURL, cfg.App.AdminEmail, cfg.App.Title)
if err := emailSender.Send(user.Email, subject, body); err != nil {
return fmt.Errorf("send password reset email: %w", err)
}
fmt.Printf("Password reset for user %s: Temporary password sent to %s\n", user.Username, user.Email)
fmt.Printf("⚠️ User must change this password on next login!\n")
return nil
}
func generateTemporaryPassword() (string, error) {
const (
length = 16
chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*"
)
password := make([]byte, length)
for i := range password {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(chars))))
if err != nil {
return "", err
}
password[i] = chars[num.Int64()]
}
passwordStr := string(password)
hasUpper := false
hasLower := false
hasDigit := false
hasSpecial := false
for _, char := range passwordStr {
switch {
case char >= 'A' && char <= 'Z':
hasUpper = true
case char >= 'a' && char <= 'z':
hasLower = true
case char >= '0' && char <= '9':
hasDigit = true
case strings.ContainsRune("!@#$%^&*", char):
hasSpecial = true
}
}
passwordBytes := []byte(passwordStr)
if !hasUpper {
passwordBytes[0] = 'A'
}
if !hasLower {
passwordBytes[1] = 'a'
}
if !hasDigit {
passwordBytes[2] = '1'
}
if !hasSpecial {
passwordBytes[3] = '!'
}
hasUpper = false
hasLower = false
hasDigit = false
hasSpecial = false
for _, char := range passwordBytes {
switch {
case char >= 'A' && char <= 'Z':
hasUpper = true
case char >= 'a' && char <= 'z':
hasLower = true
case char >= '0' && char <= '9':
hasDigit = true
case char == '!' || char == '@' || char == '#' || char == '$' || char == '%' || char == '^' || char == '&' || char == '*':
hasSpecial = true
}
}
if !hasUpper {
passwordBytes[4] = 'A'
}
if !hasLower {
passwordBytes[5] = 'a'
}
if !hasDigit {
passwordBytes[6] = '1'
}
if !hasSpecial {
passwordBytes[7] = '!'
}
return string(passwordBytes), nil
}
func generatePasswordResetEmailBody(username, tempPassword, baseURL, adminEmail, siteTitle string) string {
return fmt.Sprintf(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Password Reset - %s</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
line-height: 1.6;
color: #333;
max-width: 600px;
margin: 0 auto;
padding: 20px;
background-color: #f8fafc;
}
.email-container {
background: white;
border-radius: 12px;
padding: 40px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
border: 1px solid #e2e8f0;
}
.header {
text-align: center;
margin-bottom: 30px;
}
.logo {
font-size: 28px;
font-weight: 700;
color: #0fb9b1;
margin-bottom: 10px;
}
.title {
font-size: 24px;
font-weight: 600;
color: #1a202c;
margin: 0;
}
.content {
margin-bottom: 30px;
}
.greeting {
font-size: 16px;
margin-bottom: 20px;
color: #2d3748;
}
.message {
font-size: 16px;
margin-bottom: 30px;
color: #4a5568;
white-space: pre-line;
}
.password-box {
background: #f7fafc;
border: 2px solid #e2e8f0;
border-radius: 8px;
padding: 20px;
margin: 20px 0;
text-align: center;
}
.password-label {
font-size: 14px;
color: #718096;
margin-bottom: 10px;
}
.password-value {
font-size: 24px;
font-weight: 700;
color: #2d3748;
font-family: 'Courier New', monospace;
letter-spacing: 2px;
}
.action-button {
display: inline-block;
background: #0fb9b1;
color: white;
text-decoration: none;
padding: 12px 24px;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
text-align: center;
margin: 20px 0;
transition: background-color 0.2s;
}
.action-button:hover {
background: #0ea5a0;
}
.security-notice {
background: #fef5e7;
border: 1px solid #f6ad55;
border-radius: 8px;
padding: 20px;
margin: 20px 0;
}
.security-title {
font-weight: 600;
color: #c05621;
margin-bottom: 10px;
}
.security-list {
margin: 0;
padding-left: 20px;
color: #744210;
}
.footer {
font-size: 14px;
color: #718096;
margin-top: 30px;
padding-top: 20px;
border-top: 1px solid #e2e8f0;
white-space: pre-line;
}
.link {
color: #0fb9b1;
text-decoration: none;
}
.link:hover {
text-decoration: underline;
}
@media (max-width: 600px) {
body {
padding: 10px;
}
.email-container {
padding: 20px;
}
.title {
font-size: 20px;
}
}
</style>
</head>
<body>
<div class="email-container">
<div class="header">
<div class="logo">%s</div>
<h1 class="title">Password Reset - Temporary Password</h1>
</div>
<div class="content">
<div class="greeting">Hello %s,</div>
<div class="message">Your password has been reset by an administrator.
A temporary password has been generated for your account.</div>
<div class="password-box">
<div class="password-label">Your temporary password is:</div>
<div class="password-value">%s</div>
</div>
<div class="security-notice">
<div class="security-title">IMPORTANT SECURITY NOTICE:</div>
<ul class="security-list">
<li>You MUST change this password immediately after logging in</li>
<li>This temporary password will expire and should not be used long-term</li>
<li>Do not share this password with anyone</li>
<li>If you did not request this password reset, contact support immediately</li>
</ul>
</div>
<div style="text-align: center;">
<a href="%s/login" class="action-button">Login to %s</a>
</div>
<div class="message">To change your password:
1. Log in to your account using the temporary password above
2. Go to your account settings
3. Change your password to a new, secure password</div>
</div>
<div class="footer">
If you have any questions or concerns, please <a href="mailto:%s" class="link">contact our support team</a>.<br>
Best regards,<br>
The %s Team
</div>
<div class="powered-by" style="text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #e2e8f0; font-size: 12px; color: #718096;">
Powered with ❤️ by <a href="https://goyco" style="color: #0fb9b1; text-decoration: none;">Goyco</a>
</div>
</div>
</body>
</html>`, siteTitle, siteTitle, username, tempPassword, baseURL, siteTitle, adminEmail, siteTitle)
}

View File

@@ -0,0 +1,801 @@
package commands
import (
"errors"
"strings"
"testing"
"time"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/services"
"goyco/internal/testutils"
)
func TestHandleUserCommand(t *testing.T) {
cfg := testutils.NewTestConfig()
t.Run("help requested", func(t *testing.T) {
err := HandleUserCommand(cfg, "user", []string{"--help"})
if err != nil {
t.Errorf("unexpected error for help: %v", err)
}
})
}
func TestRunUserCommand(t *testing.T) {
cfg := testutils.NewTestConfig()
mockRepo := testutils.NewMockUserRepository()
t.Run("missing subcommand", func(t *testing.T) {
mockRefreshRepo := &mockRefreshTokenRepo{}
err := runUserCommand(cfg, mockRepo, mockRefreshRepo, []string{})
if err == nil {
t.Error("expected error for missing subcommand")
}
if err.Error() != "missing user subcommand" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("unknown subcommand", func(t *testing.T) {
mockRefreshRepo := &mockRefreshTokenRepo{}
err := runUserCommand(cfg, mockRepo, mockRefreshRepo, []string{"unknown"})
if err == nil {
t.Error("expected error for unknown subcommand")
}
expectedErr := "unknown user subcommand: unknown"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("help subcommand", func(t *testing.T) {
mockRefreshRepo := &mockRefreshTokenRepo{}
err := runUserCommand(cfg, mockRepo, mockRefreshRepo, []string{"help"})
if err != nil {
t.Errorf("unexpected error for help: %v", err)
}
})
}
func TestUserCreate(t *testing.T) {
cfg := testutils.NewTestConfig()
t.Run("successful creation", func(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
"--username", "testuser",
"--email", "test@example.com",
"--password", "StrongPass123!",
})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("missing username", func(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
"--email", "test@example.com",
"--password", "StrongPass123!",
})
if err == nil {
t.Error("expected error for missing username")
}
if err.Error() != "username, email, and password are required" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("missing email", func(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
"--username", "testuser",
"--password", "StrongPass123!",
})
if err == nil {
t.Error("expected error for missing email")
}
if err.Error() != "username, email, and password are required" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("missing password", func(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
"--username", "testuser",
"--email", "test@example.com",
})
if err == nil {
t.Error("expected error for missing password")
}
if err.Error() != "username, email, and password are required" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("password too short", func(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
"--username", "testuser",
"--email", "test@example.com",
"--password", "short",
})
if err == nil {
t.Error("expected error for short password")
}
if !strings.Contains(err.Error(), "password must be at least 8 characters") {
t.Errorf("expected password length error, got: %v", err)
}
})
t.Run("missing username value", func(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
"--username",
"--email", "test@example.com",
"--password", "StrongPass123!",
})
if err == nil {
t.Error("expected error for missing username value")
}
})
t.Run("missing email value", func(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
"--username", "testuser",
"--email",
"--password", "StrongPass123!",
})
if err == nil {
t.Error("expected error for missing email value")
}
})
t.Run("missing password value", func(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
"--username", "testuser",
"--email", "test@example.com",
"--password",
})
if err == nil {
t.Error("expected error for missing password value")
}
})
t.Run("unknown flag", func(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
"--username", "testuser",
"--email", "test@example.com",
"--password", "StrongPass123!",
"--unknown-flag",
})
if err == nil {
t.Error("expected error for unknown flag")
}
})
t.Run("duplicate flag", func(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
err := userCreate(cfg, mockRepo, []string{
"--username", "testuser",
"--email", "test@example.com",
"--password", "StrongPass123!",
"--username", "duplicate",
})
if err != nil {
if !strings.Contains(err.Error(), "required") && !strings.Contains(err.Error(), "validation") {
t.Errorf("unexpected error type: %v", err)
}
}
})
}
func TestUserUpdate(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
testUser := &database.User{
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword",
}
_ = mockRepo.Create(testUser)
t.Run("successful update username", func(t *testing.T) {
cfg := &config.Config{}
mockRefreshRepo := &mockRefreshTokenRepo{}
err := userUpdate(cfg, mockRepo, mockRefreshRepo, []string{
"1",
"--username", "newusername",
})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("successful update email", func(t *testing.T) {
cfg := &config.Config{}
mockRefreshRepo := &mockRefreshTokenRepo{}
err := userUpdate(cfg, mockRepo, mockRefreshRepo, []string{
"1",
"--email", "newemail@example.com",
})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("successful update password", func(t *testing.T) {
cfg := testutils.NewTestConfig()
mockRefreshRepo := &mockRefreshTokenRepo{}
err := userUpdate(cfg, mockRepo, mockRefreshRepo, []string{
"1",
"--password", "NewStrongPass123!",
})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("missing id", func(t *testing.T) {
cfg := &config.Config{}
mockRefreshRepo := &mockRefreshTokenRepo{}
err := userUpdate(cfg, mockRepo, mockRefreshRepo, []string{})
if err == nil {
t.Error("expected error for missing id")
}
if err.Error() != "user ID is required" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("invalid id", func(t *testing.T) {
cfg := &config.Config{}
mockRefreshRepo := &mockRefreshTokenRepo{}
err := userUpdate(cfg, mockRepo, mockRefreshRepo, []string{
"0",
"--username", "newusername",
})
if err == nil {
t.Error("expected error for invalid id")
}
if err.Error() != "user ID must be greater than 0" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("user not found", func(t *testing.T) {
cfg := &config.Config{}
mockRefreshRepo := &mockRefreshTokenRepo{}
err := userUpdate(cfg, mockRepo, mockRefreshRepo, []string{
"999",
"--username", "newusername",
})
if err == nil {
t.Error("expected error for non-existent user")
}
expectedErr := "user 999 not found"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("password too short", func(t *testing.T) {
cfg := &config.Config{}
mockRefreshRepo := &mockRefreshTokenRepo{}
err := userUpdate(cfg, mockRepo, mockRefreshRepo, []string{
"1",
"--password", "short",
})
if err == nil {
t.Error("expected error for short password")
}
if !strings.Contains(err.Error(), "password must be at least 8 characters") {
t.Errorf("expected password length error, got: %v", err)
}
})
}
func TestUserDelete(t *testing.T) {
cfg := testutils.NewTestConfig()
mockRepo := testutils.NewMockUserRepository()
testUser := &database.User{
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword",
}
_ = mockRepo.Create(testUser)
t.Run("successful delete (keep posts)", func(t *testing.T) {
err := userDelete(cfg, mockRepo, []string{"1"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("successful delete with posts", func(t *testing.T) {
testUser2 := &database.User{
Username: "testuser2",
Email: "test2@example.com",
Password: "hashedpassword",
}
_ = mockRepo.Create(testUser2)
err := userDelete(cfg, mockRepo, []string{"2", "--with-posts"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("missing id", func(t *testing.T) {
err := userDelete(cfg, mockRepo, []string{})
if err == nil {
t.Error("expected error for missing id")
}
if err.Error() != "user ID is required" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("invalid id", func(t *testing.T) {
err := userDelete(cfg, mockRepo, []string{"0"})
if err == nil {
t.Error("expected error for invalid id")
}
if err.Error() != "user ID must be greater than 0" {
t.Errorf("expected specific error, got: %v", err)
}
})
t.Run("user not found", func(t *testing.T) {
err := userDelete(cfg, mockRepo, []string{"999"})
if err == nil {
t.Error("expected error for non-existent user")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("expected 'not found' error, got: %v", err)
}
})
t.Run("user already deleted", func(t *testing.T) {
freshMockRepo := testutils.NewMockUserRepository()
testUser := &database.User{
Username: "deleteduser",
Email: "deleted@example.com",
Password: "hashedpassword",
}
_ = freshMockRepo.Create(testUser)
err := userDelete(cfg, freshMockRepo, []string{"1"})
if err != nil {
t.Errorf("unexpected error on first deletion: %v", err)
}
err = userDelete(cfg, freshMockRepo, []string{"1"})
if err == nil {
t.Error("expected error for already deleted user")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("expected 'not found' error, got: %v", err)
}
})
}
func TestUserList(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
testUsers := []*database.User{
{
Username: "user1",
Email: "user1@example.com",
Password: "password1",
CreatedAt: time.Now().Add(-2 * time.Hour),
},
{
Username: "user2",
Email: "user2@example.com",
Password: "password2",
CreatedAt: time.Now().Add(-1 * time.Hour),
},
}
for _, user := range testUsers {
_ = mockRepo.Create(user)
}
t.Run("list all users", func(t *testing.T) {
err := userList(mockRepo, []string{})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("list with limit", func(t *testing.T) {
err := userList(mockRepo, []string{"--limit", "1"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("list with offset", func(t *testing.T) {
err := userList(mockRepo, []string{"--offset", "1"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("list with all filters", func(t *testing.T) {
err := userList(mockRepo, []string{"--limit", "1", "--offset", "0"})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("empty result", func(t *testing.T) {
emptyRepo := testutils.NewMockUserRepository()
err := userList(emptyRepo, []string{})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("repository error", func(t *testing.T) {
mockRepo.GetErr = errors.New("database error")
err := userList(mockRepo, []string{})
if err == nil {
t.Error("expected error from repository")
}
expectedErr := "list users: database error"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("invalid limit type", func(t *testing.T) {
err := userList(mockRepo, []string{"--limit", "abc"})
if err == nil {
t.Error("expected error for invalid limit type")
}
})
t.Run("invalid offset type", func(t *testing.T) {
err := userList(mockRepo, []string{"--offset", "xyz"})
if err == nil {
t.Error("expected error for invalid offset type")
}
})
t.Run("unknown flag", func(t *testing.T) {
err := userList(mockRepo, []string{"--unknown-flag"})
if err == nil {
t.Error("expected error for unknown flag")
}
})
t.Run("missing limit value", func(t *testing.T) {
err := userList(mockRepo, []string{"--limit"})
if err == nil {
t.Error("expected error for missing limit value")
}
})
t.Run("missing offset value", func(t *testing.T) {
err := userList(mockRepo, []string{"--offset"})
if err == nil {
t.Error("expected error for missing offset value")
}
})
}
func TestCheckUsernameAvailable(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
testUser := &database.User{
Username: "existinguser",
Email: "test@example.com",
Password: "password",
}
_ = mockRepo.Create(testUser)
t.Run("username available", func(t *testing.T) {
err := checkUsernameAvailable(mockRepo, "newuser", 0)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("username taken by different user", func(t *testing.T) {
err := checkUsernameAvailable(mockRepo, "existinguser", 2)
if err == nil {
t.Error("expected error for taken username")
}
expectedErr := "username existinguser is already taken"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("username taken by same user (should be ok)", func(t *testing.T) {
err := checkUsernameAvailable(mockRepo, "existinguser", 1)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
func TestCheckEmailAvailable(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
testUser := &database.User{
Username: "testuser",
Email: "existing@example.com",
Password: "password",
}
_ = mockRepo.Create(testUser)
t.Run("email available", func(t *testing.T) {
err := checkEmailAvailable(mockRepo, "new@example.com", 0)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
t.Run("email taken by different user", func(t *testing.T) {
err := checkEmailAvailable(mockRepo, "existing@example.com", 2)
if err == nil {
t.Error("expected error for taken email")
}
expectedErr := "email existing@example.com is already registered"
if err.Error() != expectedErr {
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
}
})
t.Run("email taken by same user (should be ok)", func(t *testing.T) {
err := checkEmailAvailable(mockRepo, "existing@example.com", 1)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
func TestGenerateTemporaryPassword(t *testing.T) {
for range 10 {
password, err := generateTemporaryPassword()
if err != nil {
t.Fatalf("generateTemporaryPassword() error = %v", err)
}
if len(password) != 16 {
t.Errorf("Password length = %d, want 16", len(password))
}
hasUpper := false
hasLower := false
hasDigit := false
hasSpecial := false
for _, char := range password {
switch {
case char >= 'A' && char <= 'Z':
hasUpper = true
case char >= 'a' && char <= 'z':
hasLower = true
case char >= '0' && char <= '9':
hasDigit = true
case char == '!' || char == '@' || char == '#' || char == '$' || char == '%' || char == '^' || char == '&' || char == '*':
hasSpecial = true
}
}
if !hasUpper {
t.Errorf("Password %s missing uppercase letter", password)
}
if !hasLower {
t.Errorf("Password %s missing lowercase letter", password)
}
if !hasDigit {
t.Errorf("Password %s missing digit", password)
}
if !hasSpecial {
t.Errorf("Password %s missing special character", password)
}
}
}
func TestGenerateTemporaryPassword_Uniqueness(t *testing.T) {
passwords := make(map[string]bool)
for range 100 {
password, err := generateTemporaryPassword()
if err != nil {
t.Fatalf("generateTemporaryPassword() error = %v", err)
}
if passwords[password] {
t.Errorf("Duplicate password generated: %s", password)
}
passwords[password] = true
}
}
func TestResetUserPassword_WithoutEmail(t *testing.T) {
tempPassword, err := generateTemporaryPassword()
if err != nil {
t.Fatalf("generateTemporaryPassword() error = %v", err)
}
if len(tempPassword) != 16 {
t.Errorf("Password length = %d, want 16", len(tempPassword))
}
hasUpper := false
hasLower := false
hasDigit := false
hasSpecial := false
for _, char := range tempPassword {
switch {
case char >= 'A' && char <= 'Z':
hasUpper = true
case char >= 'a' && char <= 'z':
hasLower = true
case char >= '0' && char <= '9':
hasDigit = true
case char == '!' || char == '@' || char == '#' || char == '$' || char == '%' || char == '^' || char == '&' || char == '*':
hasSpecial = true
}
}
if !hasUpper {
t.Error("Password missing uppercase letter")
}
if !hasLower {
t.Error("Password missing lowercase letter")
}
if !hasDigit {
t.Error("Password missing digit")
}
if !hasSpecial {
t.Error("Password missing special character")
}
}
type mockRefreshTokenRepo struct{}
func (m *mockRefreshTokenRepo) Create(token *database.RefreshToken) error { return nil }
func (m *mockRefreshTokenRepo) GetByTokenHash(tokenHash string) (*database.RefreshToken, error) {
return nil, nil
}
func (m *mockRefreshTokenRepo) DeleteByUserID(userID uint) error { return nil }
func (m *mockRefreshTokenRepo) DeleteExpired() error { return nil }
func (m *mockRefreshTokenRepo) DeleteByID(id uint) error { return nil }
func (m *mockRefreshTokenRepo) GetByUserID(userID uint) ([]database.RefreshToken, error) {
return nil, nil
}
func (m *mockRefreshTokenRepo) CountByUserID(userID uint) (int64, error) { return 0, nil }
func TestResetUserPassword_UserNotFound(t *testing.T) {
mockRepo := testutils.NewMockUserRepository()
mockRefreshRepo := &mockRefreshTokenRepo{}
cfg := &config.Config{
JWT: config.JWTConfig{Secret: "test-secret", Expiration: 24},
}
jwtService := services.NewJWTService(&cfg.JWT, mockRepo, mockRefreshRepo)
mockSessionService := services.NewSessionService(jwtService, mockRepo)
err := resetUserPassword(cfg, mockRepo, mockSessionService, 999)
if err == nil {
t.Error("Expected error for non-existent user, got nil")
}
expectedError := "user 999 not found"
if err.Error() != expectedError {
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
}
}
func TestGeneratePasswordResetEmailBody(t *testing.T) {
username := "testuser"
title := "Test Title"
tempPassword := "TempPass123!"
baseURL := "https://example.com"
adminEmail := "admin@example.com"
body := generatePasswordResetEmailBody(username, tempPassword, baseURL, adminEmail, title)
if !strings.Contains(body, username) {
t.Error("Email body does not contain username")
}
if !strings.Contains(body, tempPassword) {
t.Error("Email body does not contain temporary password")
}
if !strings.Contains(body, baseURL) {
t.Error("Email body does not contain base URL")
}
if !strings.Contains(body, "IMPORTANT SECURITY NOTICE") {
t.Error("Email body does not contain security notice")
}
if !strings.Contains(body, "<!DOCTYPE html>") {
t.Error("Email body is not HTML")
}
if !strings.Contains(body, "mailto:"+adminEmail) {
t.Error("Email body does not contain admin contact link")
}
}

208
cmd/goyco/fuzz_test.go Normal file
View File

@@ -0,0 +1,208 @@
package main
import (
"flag"
"fmt"
"os"
"strings"
"testing"
"unicode/utf8"
"goyco/cmd/goyco/commands"
"goyco/internal/config"
"goyco/internal/testutils"
"gorm.io/gorm"
)
func FuzzCLIArgs(f *testing.F) {
f.Add("")
f.Add("run")
f.Add("--help")
f.Add("user list")
f.Add("post search")
f.Add("migrate")
f.Fuzz(func(t *testing.T, input string) {
if !isValidUTF8(input) {
return
}
if len(input) > 1000 {
input = input[:1000]
}
args := strings.Fields(input)
if len(args) == 0 {
return
}
fs := flag.NewFlagSet("goyco", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
fs.Usage = printRootUsage
showHelp := fs.Bool("help", false, "show this help message")
err := fs.Parse(args)
if err != nil {
if !strings.Contains(err.Error(), "flag") && !strings.Contains(err.Error(), "help") {
t.Logf("Unexpected error format from flag parsing: %v", err)
}
}
if *showHelp && err != nil {
return
}
remaining := fs.Args()
if len(remaining) > 0 {
cmdName := remaining[0]
if len(cmdName) == 0 {
t.Fatal("Command name cannot be empty")
}
if !isValidUTF8(cmdName) {
t.Fatal("Command name must be valid UTF-8")
}
}
})
}
func FuzzCommandDispatch(f *testing.F) {
cfg := testutils.NewTestConfig()
setRunServer(func(_ *config.Config, _ bool) error {
return nil
})
defer setRunServer(runServerImpl)
originalRunServer := runServerImpl
commands.SetRunServer(func(_ *config.Config, _ bool) error {
return nil
})
defer commands.SetRunServer(originalRunServer)
commands.SetDaemonize(func() (int, error) {
return 999, nil
})
defer commands.SetDaemonize(nil)
commands.SetSetupDaemonLogging(func(_ *config.Config, _ string) error {
return nil
})
defer commands.SetSetupDaemonLogging(nil)
commands.SetDBConnector(func(_ *config.Config) (*gorm.DB, func() error, error) {
return nil, nil, fmt.Errorf("database connection disabled in fuzzer")
})
defer commands.SetDBConnector(nil)
daemonCommands := map[string]bool{
"start": true,
"stop": true,
"status": true,
}
f.Add("run")
f.Add("help")
f.Add("user")
f.Add("post")
f.Add("migrate")
f.Add("unknown_command")
f.Add("--help")
f.Add("-h")
f.Fuzz(func(t *testing.T, input string) {
if !isValidUTF8(input) {
return
}
parts := strings.Fields(input)
if len(parts) == 0 {
return
}
cmdName := parts[0]
args := parts[1:]
if daemonCommands[cmdName] {
return
}
err := dispatchCommand(cfg, cmdName, args)
knownCommands := map[string]bool{
"run": true, "user": true, "post": true, "prune": true, "migrate": true,
"migrations": true, "seed": true, "help": true, "-h": true, "--help": true,
}
if knownCommands[cmdName] {
if err != nil && !strings.Contains(err.Error(), cmdName) {
t.Logf("Known command %q returned unexpected error: %v", cmdName, err)
}
} else {
if err == nil {
t.Fatalf("Unknown command %q should return an error", cmdName)
}
if !strings.Contains(err.Error(), cmdName) {
t.Fatalf("Error for unknown command should contain command name: %v", err)
}
}
})
}
func FuzzRunCommandHandler(f *testing.F) {
cfg := testutils.NewTestConfig()
setRunServer(func(_ *config.Config, _ bool) error {
return nil
})
defer setRunServer(runServerImpl)
f.Add("")
f.Add("--help")
f.Add("extra arg")
f.Add("--invalid")
f.Fuzz(func(t *testing.T, input string) {
if !isValidUTF8(input) {
return
}
args := strings.Fields(input)
err := handleRunCommand(cfg, args)
if len(args) > 0 && args[0] == "--help" {
if err != nil {
t.Logf("Help flag should not error, got: %v", err)
}
} else if len(args) > 0 {
if err == nil {
return
}
errMsg := err.Error()
if strings.Contains(errMsg, "flag provided but not defined") ||
strings.Contains(errMsg, "failed to parse") {
return
}
if !strings.Contains(errMsg, "unexpected arguments") {
t.Logf("Got error (may be acceptable for server setup): %v", err)
}
} else {
if err != nil && strings.Contains(err.Error(), "unexpected arguments") {
t.Fatalf("Empty args should not trigger 'unexpected arguments' error: %v", err)
}
}
})
}
func isValidUTF8(s string) bool {
for _, r := range s {
if r == utf8.RuneError {
return false
}
}
return true
}

136
cmd/goyco/main.go Normal file
View File

@@ -0,0 +1,136 @@
// @title Goyco API
// @version 0.1.0
// @description Goyco is a Y Combinator-style news aggregation platform API.
// @contact.name Goyco Team
// @contact.email sandro@cazzaniga.fr
// @license.name GPLv3
// @license.url https://www.gnu.org/licenses/gpl-3.0.html
// @host localhost:8080
// @schemes http
// @BasePath /api
package main
import (
"errors"
"flag"
"fmt"
"log"
"os"
"goyco/cmd/goyco/commands"
"goyco/docs"
"goyco/internal/config"
"goyco/internal/version"
)
func main() {
loadDotEnv()
commands.SetRunServer(runServerImpl)
if len(os.Args) > 1 && os.Args[len(os.Args)-1] == "--daemon" {
args := os.Args[1 : len(os.Args)-1]
if err := commands.RunDaemonProcessDirect(args); err != nil {
log.Fatalf("daemon error: %v", err)
}
return
}
if err := run(os.Args[1:]); err != nil {
log.Fatalf("error: %v", err)
}
}
func run(args []string) error {
cfg, err := config.Load()
if err != nil {
return fmt.Errorf("load configuration: %w", err)
}
validator := commands.NewConfigValidator(nil)
if err := validator.ValidateConfiguration(cfg); err != nil {
return fmt.Errorf("configuration validation failed: %w", err)
}
docs.SwaggerInfo.Title = fmt.Sprintf("%s API", cfg.App.Title)
docs.SwaggerInfo.Description = "Y Combinator-style news board API."
docs.SwaggerInfo.Version = version.Version
docs.SwaggerInfo.BasePath = "/api"
docs.SwaggerInfo.Host = fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port)
docs.SwaggerInfo.Schemes = []string{"http"}
if cfg.Server.EnableTLS {
docs.SwaggerInfo.Schemes = append(docs.SwaggerInfo.Schemes, "https")
}
rootFS := flag.NewFlagSet("goyco", flag.ContinueOnError)
rootFS.SetOutput(os.Stderr)
rootFS.Usage = printRootUsage
showHelp := rootFS.Bool("help", false, "show this help message")
if err := rootFS.Parse(args); err != nil {
if errors.Is(err, flag.ErrHelp) {
return nil
}
return fmt.Errorf("failed to parse arguments: %w", err)
}
if *showHelp {
printRootUsage()
return nil
}
remaining := rootFS.Args()
if len(remaining) == 0 {
printRootUsage()
return nil
}
return dispatchCommand(cfg, remaining[0], remaining[1:])
}
func dispatchCommand(cfg *config.Config, name string, args []string) error {
switch name {
case "run":
return handleRunCommand(cfg, args)
case "start":
return commands.HandleStartCommand(cfg, args)
case "stop":
return commands.HandleStopCommand(cfg, args)
case "status":
return commands.HandleStatusCommand(cfg, name, args)
case "user":
return commands.HandleUserCommand(cfg, name, args)
case "post":
return commands.HandlePostCommand(cfg, name, args)
case "prune":
return commands.HandlePruneCommand(cfg, name, args)
case "migrate", "migrations":
return commands.HandleMigrateCommand(cfg, name, args)
case "seed":
return commands.HandleSeedCommand(cfg, name, args)
case "help", "-h", "--help":
printRootUsage()
return nil
default:
printRootUsage()
return fmt.Errorf("unknown command: %s", name)
}
}
func handleRunCommand(cfg *config.Config, args []string) error {
fs := newFlagSet("run", printRunUsage)
if err := parseCommand(fs, args, "run"); err != nil {
if errors.Is(err, commands.ErrHelpRequested) {
return nil
}
return err
}
if fs.NArg() > 0 {
printRunUsage()
return errors.New("unexpected arguments for run command")
}
return runServer(cfg, false)
}

149
cmd/goyco/server.go Normal file
View File

@@ -0,0 +1,149 @@
package main
import (
"crypto/tls"
"fmt"
"log"
"net/http"
"goyco/cmd/goyco/commands"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/handlers"
"goyco/internal/middleware"
"goyco/internal/repositories"
"goyco/internal/server"
"goyco/internal/services"
_ "goyco/docs"
)
func runServerImpl(cfg *config.Config, daemon bool) error {
if daemon {
if err := commands.SetupDaemonLogging(cfg, cfg.LogDir); err != nil {
return fmt.Errorf("setup daemon logging: %w", err)
}
}
dbMonitor := middleware.NewInMemoryDBMonitor()
poolManager, err := database.ConnectWithPool(cfg)
if err != nil {
return fmt.Errorf("connect to database: %w", err)
}
defer func() {
middleware.StopAllRateLimiters()
if err := poolManager.Close(); err != nil {
log.Printf("Error closing database pool: %v", err)
}
}()
db := poolManager.GetDB()
if err := database.Migrate(db); err != nil {
return fmt.Errorf("run migrations: %w", err)
}
if monitor := dbMonitor; monitor != nil {
monitoringPlugin := database.NewGormDBMonitor(monitor)
if err := db.Use(monitoringPlugin); err != nil {
return fmt.Errorf("failed to add monitoring plugin: %w", err)
}
}
voteRepository := repositories.NewVoteRepository(db)
postRepository := repositories.NewPostRepository(db)
userRepository := repositories.NewUserRepository(db)
deletionRepository := repositories.NewAccountDeletionRepository(db)
refreshTokenRepository := repositories.NewRefreshTokenRepository(db)
emailSender := services.NewSMTPSender(cfg.SMTP.Host, cfg.SMTP.Port, cfg.SMTP.Username, cfg.SMTP.Password, cfg.SMTP.From)
emailService, err := services.NewEmailService(cfg, emailSender)
if err != nil {
return fmt.Errorf("create email service: %w", err)
}
jwtService := services.NewJWTService(&cfg.JWT, userRepository, refreshTokenRepository)
registrationService := services.NewRegistrationService(userRepository, emailService, cfg)
passwordResetService := services.NewPasswordResetService(userRepository, emailService)
deletionService := services.NewAccountDeletionService(userRepository, postRepository, deletionRepository, emailService)
sessionService := services.NewSessionService(jwtService, userRepository)
userManagementService := services.NewUserManagementService(userRepository, postRepository, emailService)
authFacade := services.NewAuthFacade(
registrationService,
passwordResetService,
deletionService,
sessionService,
userManagementService,
cfg,
)
voteService := services.NewVoteService(voteRepository, postRepository, db)
voteHandler := handlers.NewVoteHandler(voteService)
metadataService := services.NewURLMetadataService()
postHandler := handlers.NewPostHandler(postRepository, metadataService, voteService)
userHandler := handlers.NewUserHandler(userRepository, authFacade)
authHandler := handlers.NewAuthHandler(authFacade, userRepository)
apiHandler := handlers.NewAPIHandlerWithMonitoring(cfg, postRepository, userRepository, voteService, db, dbMonitor)
pageHandler, err := handlers.NewPageHandler("./internal/templates", authFacade, postRepository, voteService, userRepository, metadataService, cfg)
if err != nil {
return fmt.Errorf("load templates: %w", err)
}
router := server.NewRouter(server.RouterConfig{
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
APIHandler: apiHandler,
AuthService: authFacade,
PageHandler: pageHandler,
StaticDir: "./internal/static/",
Debug: cfg.App.Debug,
DBMonitor: dbMonitor,
RateLimitConfig: cfg.RateLimit,
})
serverAddr := cfg.Server.Host + ":" + cfg.Server.Port
log.Printf("Server starting on %s", serverAddr)
srv := &http.Server{
Addr: serverAddr,
Handler: router,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
MaxHeaderBytes: cfg.Server.MaxHeaderBytes,
}
if cfg.Server.EnableTLS {
log.Printf("TLS enabled")
srv.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
},
}
return srv.ListenAndServeTLS(cfg.Server.TLSCertFile, cfg.Server.TLSKeyFile)
}
log.Printf("WARNING: Server is running on plain HTTP. Enable TLS for production use.")
return srv.ListenAndServe()
}
var runServer = runServerImpl
func setRunServer(fn func(cfg *config.Config, daemon bool) error) {
runServer = fn
}

393
cmd/goyco/server_test.go Normal file
View File

@@ -0,0 +1,393 @@
package main
import (
"crypto/tls"
"errors"
"flag"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/handlers"
"goyco/internal/repositories"
"goyco/internal/server"
"goyco/internal/services"
"goyco/internal/testutils"
)
func TestServerConfigurationFromConfig(t *testing.T) {
cfg := testutils.NewTestConfig()
cfg.Server.ReadTimeout = 30 * time.Second
cfg.Server.WriteTimeout = 30 * time.Second
cfg.Server.IdleTimeout = 120 * time.Second
cfg.Server.MaxHeaderBytes = 1 << 20
db := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db.DB()
_ = sqlDB.Close()
}()
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
deletionRepo := repositories.NewAccountDeletionRepository(db)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
emailSender := &testutils.MockEmailSender{}
authService, err := services.NewAuthFacadeForTest(cfg, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
if err != nil {
t.Fatalf("Failed to create auth service: %v", err)
}
voteService := services.NewVoteService(voteRepo, postRepo, db)
metadataService := services.NewURLMetadataService()
authHandler := handlers.NewAuthHandler(authService, userRepo)
postHandler := handlers.NewPostHandler(postRepo, metadataService, voteService)
voteHandler := handlers.NewVoteHandler(voteService)
userHandler := handlers.NewUserHandler(userRepo, authService)
apiHandler := handlers.NewAPIHandler(cfg, postRepo, userRepo, voteService)
router := server.NewRouter(server.RouterConfig{
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
APIHandler: apiHandler,
AuthService: authService,
StaticDir: "./internal/static/",
Debug: cfg.App.Debug,
DisableCache: true,
DisableCompression: true,
RateLimitConfig: cfg.RateLimit,
})
srv := &http.Server{
Addr: cfg.Server.Host + ":" + cfg.Server.Port,
Handler: router,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
MaxHeaderBytes: cfg.Server.MaxHeaderBytes,
}
if srv.ReadTimeout != 30*time.Second {
t.Errorf("Expected ReadTimeout to be 30s, got %v", srv.ReadTimeout)
}
if srv.WriteTimeout != 30*time.Second {
t.Errorf("Expected WriteTimeout to be 30s, got %v", srv.WriteTimeout)
}
if srv.IdleTimeout != 120*time.Second {
t.Errorf("Expected IdleTimeout to be 120s, got %v", srv.IdleTimeout)
}
if srv.MaxHeaderBytes != 1<<20 {
t.Errorf("Expected MaxHeaderBytes to be 1MB, got %d", srv.MaxHeaderBytes)
}
testServer := httptest.NewServer(srv.Handler)
defer testServer.Close()
resp, err := http.Get(testServer.URL + "/health")
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
}
func TestTLSWiringFromConfig(t *testing.T) {
cfg := testutils.NewTestConfig()
cfg.Server.EnableTLS = true
cfg.Server.TLSCertFile = "/tmp/nonexistent-cert.pem"
cfg.Server.TLSKeyFile = "/tmp/nonexistent-key.pem"
db := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db.DB()
_ = sqlDB.Close()
}()
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
deletionRepo := repositories.NewAccountDeletionRepository(db)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
emailSender := &testutils.MockEmailSender{}
authService, err := services.NewAuthFacadeForTest(cfg, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
if err != nil {
t.Fatalf("Failed to create auth service: %v", err)
}
voteService := services.NewVoteService(voteRepo, postRepo, db)
metadataService := services.NewURLMetadataService()
authHandler := handlers.NewAuthHandler(authService, userRepo)
postHandler := handlers.NewPostHandler(postRepo, metadataService, voteService)
voteHandler := handlers.NewVoteHandler(voteService)
userHandler := handlers.NewUserHandler(userRepo, authService)
apiHandler := handlers.NewAPIHandler(cfg, postRepo, userRepo, voteService)
router := server.NewRouter(server.RouterConfig{
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
APIHandler: apiHandler,
AuthService: authService,
StaticDir: "./internal/static/",
Debug: cfg.App.Debug,
DisableCache: true,
DisableCompression: true,
RateLimitConfig: cfg.RateLimit,
})
expectedAddr := cfg.Server.Host + ":" + cfg.Server.Port
srv := &http.Server{
Addr: expectedAddr,
Handler: router,
}
if srv.Addr != expectedAddr {
t.Errorf("Expected server address to be %q, got %q", expectedAddr, srv.Addr)
}
if cfg.Server.EnableTLS {
srv.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
},
}
if srv.TLSConfig == nil {
t.Error("Expected TLS config to be set")
}
if srv.TLSConfig.MinVersion < tls.VersionTLS12 {
t.Error("Expected minimum TLS version to be 1.2 or higher")
}
if len(srv.TLSConfig.CipherSuites) == 0 {
t.Error("Expected cipher suites to be configured")
}
testServer := httptest.NewUnstartedServer(srv.Handler)
testServer.TLS = srv.TLSConfig
testServer.StartTLS()
defer testServer.Close()
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
resp, err := client.Get(testServer.URL + "/health")
if err != nil {
t.Fatalf("Failed to make TLS request: %v", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 over TLS, got %d", resp.StatusCode)
}
if resp.TLS == nil {
t.Error("Expected TLS connection info to be present in response")
} else {
if resp.TLS.Version < tls.VersionTLS12 {
t.Errorf("Expected TLS version 1.2 or higher, got %x", resp.TLS.Version)
}
}
}
}
func TestConfigLoadingInCLI(t *testing.T) {
originalEnv := os.Environ()
defer func() {
os.Clearenv()
for _, env := range originalEnv {
parts := splitEnv(env)
if len(parts) == 2 {
_ = os.Setenv(parts[0], parts[1])
}
}
}()
os.Clearenv()
_ = os.Setenv("DB_PASSWORD", "test-password-123")
_ = os.Setenv("SMTP_HOST", "smtp.example.com")
_ = os.Setenv("SMTP_FROM", "test@example.com")
_ = os.Setenv("ADMIN_EMAIL", "admin@example.com")
_ = os.Setenv("JWT_SECRET", "test-jwt-secret-key-that-is-long-enough-for-validation")
cfg, err := config.Load()
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
if cfg.Server.Port == "" {
t.Error("Expected server port to be set")
}
if cfg.Database.Host == "" {
t.Error("Expected database host to be set")
}
}
func TestFlagParsingInCLI(t *testing.T) {
originalArgs := os.Args
defer func() {
os.Args = originalArgs
}()
t.Run("help flag", func(t *testing.T) {
os.Args = []string{"goyco", "--help"}
fs := flag.NewFlagSet("goyco", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
showHelp := fs.Bool("help", false, "show help")
err := fs.Parse([]string{"--help"})
if err != nil && !errors.Is(err, flag.ErrHelp) {
t.Errorf("Expected help flag parsing, got error: %v", err)
}
if !*showHelp {
t.Error("Expected help flag to be true")
}
})
t.Run("command dispatch", func(t *testing.T) {
cfg := testutils.NewTestConfig()
err := dispatchCommand(cfg, "unknown", []string{})
if err == nil {
t.Error("Expected error for unknown command")
}
err = dispatchCommand(cfg, "help", []string{})
if err != nil {
t.Errorf("Help command should not error: %v", err)
}
})
}
func TestServerInitializationFlow(t *testing.T) {
cfg := testutils.NewTestConfig()
cfg.Server.Port = "0"
db := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db.DB()
_ = sqlDB.Close()
}()
if err := database.Migrate(db); err != nil {
t.Fatalf("Failed to run migrations: %v", err)
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
deletionRepo := repositories.NewAccountDeletionRepository(db)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
emailSender := &testutils.MockEmailSender{}
authService, err := services.NewAuthFacadeForTest(cfg, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
if err != nil {
t.Fatalf("Failed to create auth service: %v", err)
}
voteService := services.NewVoteService(voteRepo, postRepo, db)
metadataService := services.NewURLMetadataService()
authHandler := handlers.NewAuthHandler(authService, userRepo)
postHandler := handlers.NewPostHandler(postRepo, metadataService, voteService)
voteHandler := handlers.NewVoteHandler(voteService)
userHandler := handlers.NewUserHandler(userRepo, authService)
apiHandler := handlers.NewAPIHandler(cfg, postRepo, userRepo, voteService)
router := server.NewRouter(server.RouterConfig{
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
APIHandler: apiHandler,
AuthService: authService,
StaticDir: "./internal/static/",
Debug: cfg.App.Debug,
DisableCache: true,
DisableCompression: true,
RateLimitConfig: cfg.RateLimit,
})
srv := &http.Server{
Addr: cfg.Server.Host + ":" + cfg.Server.Port,
Handler: router,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
MaxHeaderBytes: cfg.Server.MaxHeaderBytes,
}
if srv.Handler == nil {
t.Error("Expected server handler to be set")
}
testServer := httptest.NewServer(srv.Handler)
defer testServer.Close()
resp, err := http.Get(testServer.URL + "/health")
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
resp, err = http.Get(testServer.URL + "/api")
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for API endpoint, got %d", resp.StatusCode)
}
}
func splitEnv(env string) []string {
for i := 0; i < len(env); i++ {
if env[i] == '=' {
return []string{env[:i], env[i+1:]}
}
}
return []string{env}
}