To gitea and beyond, let's go(-yco)
This commit is contained in:
89
internal/fuzz/db.go
Normal file
89
internal/fuzz/db.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package fuzz
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
fuzzDBOnce sync.Once
|
||||
fuzzDB *gorm.DB
|
||||
fuzzDBErr error
|
||||
)
|
||||
|
||||
func GetFuzzDB() (*gorm.DB, error) {
|
||||
fuzzDBOnce.Do(func() {
|
||||
dbName := "file:memdb_fuzz?mode=memory&cache=shared&_journal_mode=WAL&_synchronous=NORMAL"
|
||||
fuzzDB, fuzzDBErr = gorm.Open(sqlite.Open(dbName), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if fuzzDBErr == nil {
|
||||
fuzzDBErr = fuzzDB.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password TEXT NOT NULL,
|
||||
email_verified INTEGER DEFAULT 0 NOT NULL,
|
||||
email_verified_at DATETIME,
|
||||
email_verification_token TEXT,
|
||||
email_verification_sent_at DATETIME,
|
||||
password_reset_token TEXT,
|
||||
password_reset_sent_at DATETIME,
|
||||
password_reset_expires_at DATETIME,
|
||||
locked INTEGER DEFAULT 0,
|
||||
session_version INTEGER DEFAULT 1 NOT NULL,
|
||||
created_at DATETIME,
|
||||
updated_at DATETIME,
|
||||
deleted_at DATETIME
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS posts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
title TEXT NOT NULL,
|
||||
url TEXT UNIQUE,
|
||||
content TEXT,
|
||||
author_id INTEGER,
|
||||
author_name TEXT,
|
||||
up_votes INTEGER DEFAULT 0,
|
||||
down_votes INTEGER DEFAULT 0,
|
||||
score INTEGER DEFAULT 0,
|
||||
created_at DATETIME,
|
||||
updated_at DATETIME,
|
||||
deleted_at DATETIME,
|
||||
FOREIGN KEY(author_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS votes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER,
|
||||
post_id INTEGER NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
vote_hash TEXT,
|
||||
created_at DATETIME,
|
||||
updated_at DATETIME,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id),
|
||||
FOREIGN KEY(post_id) REFERENCES posts(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS account_deletion_requests (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
token_hash TEXT UNIQUE NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
created_at DATETIME,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
token_hash TEXT UNIQUE NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
created_at DATETIME,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
);
|
||||
`).Error
|
||||
}
|
||||
})
|
||||
return fuzzDB, fuzzDBErr
|
||||
}
|
||||
226
internal/fuzz/fuzz.go
Normal file
226
internal/fuzz/fuzz.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package fuzz
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type FuzzTestHelper struct{}
|
||||
|
||||
func NewFuzzTestHelper() *FuzzTestHelper {
|
||||
return &FuzzTestHelper{}
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) RunBasicFuzzTest(f *testing.F, testFunc func(t *testing.T, input string)) {
|
||||
f.Add("test input")
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
if !utf8.ValidString(input) {
|
||||
return
|
||||
}
|
||||
testFunc(t, input)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) RunValidationFuzzTest(f *testing.F, validateFunc func(string) error) {
|
||||
h.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
||||
err := validateFunc(input)
|
||||
_ = err
|
||||
})
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) RunSanitizationFuzzTest(f *testing.F, sanitizeFunc func(string) string) {
|
||||
h.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
||||
result := sanitizeFunc(input)
|
||||
if !utf8.ValidString(result) {
|
||||
t.Fatal("Sanitized result contains invalid UTF-8")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) RunSanitizationFuzzTestWithValidation(f *testing.F, sanitizeFunc func(string) string, validateFunc func(string) bool) {
|
||||
h.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
||||
result := sanitizeFunc(input)
|
||||
if !utf8.ValidString(result) {
|
||||
t.Fatal("Sanitized result contains invalid UTF-8")
|
||||
}
|
||||
if validateFunc != nil {
|
||||
if !validateFunc(result) {
|
||||
t.Fatal("Sanitized result failed validation")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) RunJSONFuzzTest(f *testing.F, testCases []map[string]any) {
|
||||
h.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
||||
for _, tc := range testCases {
|
||||
body, ok := tc["body"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
encoded, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
encodedStr := string(encoded)
|
||||
body = strings.ReplaceAll(body, "FUZZED_INPUT", encodedStr)
|
||||
|
||||
var result map[string]any
|
||||
err = json.Unmarshal([]byte(body), &result)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) RunHTTPFuzzTest(f *testing.F, testCases []HTTPFuzzTestCase) {
|
||||
h.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
||||
for _, tc := range testCases {
|
||||
|
||||
sanitized := h.sanitizeForURL(input)
|
||||
|
||||
url := strings.ReplaceAll(tc.URL, "FUZZED_INPUT", sanitized)
|
||||
body := strings.ReplaceAll(tc.Body, "FUZZED_INPUT", sanitized)
|
||||
|
||||
req := httptest.NewRequest(tc.Method, url, bytes.NewBufferString(body))
|
||||
|
||||
for name, value := range tc.Headers {
|
||||
req.Header.Set(name, value)
|
||||
}
|
||||
|
||||
h.validateHTTPRequest(t, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) sanitizeForURL(input string) string {
|
||||
sanitized := strings.ReplaceAll(input, "\n", "")
|
||||
sanitized = strings.ReplaceAll(sanitized, "\r", "")
|
||||
sanitized = strings.ReplaceAll(sanitized, "\t", "")
|
||||
sanitized = url.QueryEscape(sanitized)
|
||||
sanitized = strings.ReplaceAll(sanitized, "+", "%20")
|
||||
|
||||
if len(sanitized) > 100 {
|
||||
sanitized = sanitized[:100]
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
|
||||
type HTTPFuzzTestCase struct {
|
||||
Name string
|
||||
Method string
|
||||
URL string
|
||||
Headers map[string]string
|
||||
Body string
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) validateHTTPRequest(t *testing.T, req *http.Request) {
|
||||
pathParts := strings.Split(req.URL.Path, "/")
|
||||
for _, part := range pathParts {
|
||||
if !utf8.ValidString(part) {
|
||||
t.Fatal("Path contains invalid UTF-8")
|
||||
}
|
||||
}
|
||||
|
||||
for name, values := range req.URL.Query() {
|
||||
if !utf8.ValidString(name) {
|
||||
t.Fatal("Query parameter name contains invalid UTF-8")
|
||||
}
|
||||
for _, value := range values {
|
||||
if !utf8.ValidString(value) {
|
||||
t.Fatal("Query parameter value contains invalid UTF-8")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name, values := range req.Header {
|
||||
if !utf8.ValidString(name) {
|
||||
t.Fatal("Header name contains invalid UTF-8")
|
||||
}
|
||||
for _, value := range values {
|
||||
if !utf8.ValidString(value) {
|
||||
t.Fatal("Header value contains invalid UTF-8")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) RunIntegrationFuzzTest(f *testing.F, testFunc func(t *testing.T, input string)) {
|
||||
h.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
||||
|
||||
if len(input) > 1000 {
|
||||
input = input[:1000]
|
||||
}
|
||||
|
||||
testFunc(t, input)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) GetCommonAuthTestCases(input string) []HTTPFuzzTestCase {
|
||||
return []HTTPFuzzTestCase{
|
||||
{
|
||||
Name: "auth_register",
|
||||
Method: "POST",
|
||||
URL: "/api/auth/register",
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
Body: `{"username":"FUZZED_INPUT","email":"test@example.com","password":"test123"}`,
|
||||
},
|
||||
{
|
||||
Name: "auth_login",
|
||||
Method: "POST",
|
||||
URL: "/api/auth/login",
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
Body: `{"username":"FUZZED_INPUT","password":"test123"}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) GetCommonPostTestCases(input string) []HTTPFuzzTestCase {
|
||||
return []HTTPFuzzTestCase{
|
||||
{
|
||||
Name: "post_create",
|
||||
Method: "POST",
|
||||
URL: "/api/posts",
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer FUZZED_INPUT",
|
||||
},
|
||||
Body: `{"title":"FUZZED_INPUT","url":"https://example.com","content":"test"}`,
|
||||
},
|
||||
{
|
||||
Name: "post_search",
|
||||
Method: "GET",
|
||||
URL: "/api/posts/search?q=FUZZED_INPUT",
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *FuzzTestHelper) GetCommonVoteTestCases(input string) []HTTPFuzzTestCase {
|
||||
return []HTTPFuzzTestCase{
|
||||
{
|
||||
Name: "vote_cast",
|
||||
Method: "POST",
|
||||
URL: "/api/posts/1/vote",
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer FUZZED_INPUT",
|
||||
},
|
||||
Body: `{"type":"FUZZED_INPUT"}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
1724
internal/fuzz/fuzz_test.go
Normal file
1724
internal/fuzz/fuzz_test.go
Normal file
File diff suppressed because it is too large
Load Diff
298
internal/fuzz/integration_fuzz_test.go
Normal file
298
internal/fuzz/integration_fuzz_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package fuzz
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"goyco/internal/handlers"
|
||||
"goyco/internal/middleware"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/services"
|
||||
"goyco/internal/testutils"
|
||||
)
|
||||
|
||||
func FuzzIntegrationHandlers(f *testing.F) {
|
||||
f.Add("testuser")
|
||||
f.Add("test@example.com")
|
||||
f.Add("password123")
|
||||
f.Add("")
|
||||
f.Add("<script>alert('xss')</script>")
|
||||
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
if len(input) > 500 {
|
||||
input = input[:500]
|
||||
}
|
||||
|
||||
if !isValidUTF8(input) {
|
||||
return
|
||||
}
|
||||
|
||||
db := testutils.NewTestDB(t)
|
||||
defer func() {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
}()
|
||||
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
postRepo := repositories.NewPostRepository(db)
|
||||
voteRepo := repositories.NewVoteRepository(db)
|
||||
deletionRepo := repositories.NewAccountDeletionRepository(db)
|
||||
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
authService, err := services.NewAuthFacadeForTest(testutils.AppTestConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create auth service: %v", err)
|
||||
}
|
||||
|
||||
voteService := services.NewVoteService(voteRepo, postRepo, db)
|
||||
titleFetcher := &testutils.MockTitleFetcher{}
|
||||
|
||||
authHandler := handlers.NewAuthHandler(authService, userRepo)
|
||||
postHandler := handlers.NewPostHandler(postRepo, titleFetcher, voteService)
|
||||
apiHandler := handlers.NewAPIHandler(testutils.AppTestConfig, postRepo, userRepo, voteService)
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Use(middleware.Logging(false))
|
||||
router.Use(middleware.SecurityHeadersMiddleware())
|
||||
router.Use(middleware.GeneralRateLimitMiddleware())
|
||||
|
||||
router.Route("/api", func(r chi.Router) {
|
||||
r.Post("/auth/register", authHandler.Register)
|
||||
r.Post("/auth/login", authHandler.Login)
|
||||
r.Get("/posts/search", postHandler.SearchPosts)
|
||||
r.Get("/posts", postHandler.GetPosts)
|
||||
|
||||
r.Group(func(protected chi.Router) {
|
||||
protected.Use(middleware.NewAuth(authService))
|
||||
protected.Get("/auth/me", authHandler.Me)
|
||||
protected.Post("/posts", postHandler.CreatePost)
|
||||
})
|
||||
})
|
||||
|
||||
router.Get("/health", apiHandler.GetHealth)
|
||||
|
||||
t.Run("register_endpoint", func(t *testing.T) {
|
||||
username := input[:min(len(input), 50)]
|
||||
email := input[:min(len(input), 50)] + "@example.com"
|
||||
password := input[:min(len(input), 128)]
|
||||
if len(password) < 8 {
|
||||
password = password + "12345678"
|
||||
}
|
||||
|
||||
registerBody := fmt.Sprintf(`{"username":"%s","email":"%s","password":"%s"}`,
|
||||
escapeJSON(username), escapeJSON(email), escapeJSON(password))
|
||||
|
||||
req, _ := http.NewRequest("POST", "/api/auth/register", bytes.NewBufferString(registerBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code == 0 {
|
||||
t.Fatal("Handler should return a status code")
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusCreated && resp.Code != http.StatusBadRequest {
|
||||
t.Logf("Unexpected status code %d for register (expected 201 or 400)", resp.Code)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("Response should be valid JSON: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("search_endpoint", func(t *testing.T) {
|
||||
query := input[:min(len(input), 200)]
|
||||
escapedQuery := url.QueryEscape(query)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/posts/search?q="+escapedQuery, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code == 0 {
|
||||
t.Fatal("Handler should return a status code")
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Logf("Unexpected status code %d for search (expected 200)", resp.Code)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("Response should be valid JSON: %v", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzIntegrationServices(f *testing.F) {
|
||||
f.Add("testuser")
|
||||
f.Add("test@example.com")
|
||||
f.Add("password123")
|
||||
f.Add("")
|
||||
f.Add("a")
|
||||
f.Add(strings.Repeat("x", 100))
|
||||
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
if len(input) > 200 {
|
||||
input = input[:200]
|
||||
}
|
||||
|
||||
if !utf8.ValidString(input) {
|
||||
return
|
||||
}
|
||||
|
||||
db := testutils.NewTestDB(t)
|
||||
defer func() {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
}()
|
||||
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
postRepo := repositories.NewPostRepository(db)
|
||||
deletionRepo := repositories.NewAccountDeletionRepository(db)
|
||||
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
authService, err := services.NewAuthFacadeForTest(testutils.AppTestConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create auth service: %v", err)
|
||||
}
|
||||
|
||||
usernameLen := len(input)
|
||||
if usernameLen > 50 {
|
||||
usernameLen = 50
|
||||
}
|
||||
username := input[:usernameLen]
|
||||
email := input[:usernameLen] + "@example.com"
|
||||
|
||||
passwordLen := len(input)
|
||||
if passwordLen > 128 {
|
||||
passwordLen = 128
|
||||
}
|
||||
password := input[:passwordLen]
|
||||
|
||||
if len(password) < 8 {
|
||||
password = password + "12345678"
|
||||
}
|
||||
|
||||
result, err := authService.Register(username, email, password)
|
||||
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "panic") || strings.Contains(err.Error(), "nil pointer") {
|
||||
t.Fatalf("Registration should not panic: %v", err)
|
||||
}
|
||||
} else {
|
||||
if result.User == nil {
|
||||
t.Fatal("Registration result should contain a user")
|
||||
}
|
||||
if result.User.Username != username {
|
||||
t.Fatalf("Expected username %q, got %q", username, result.User.Username)
|
||||
}
|
||||
if !strings.EqualFold(result.User.Email, email) {
|
||||
t.Fatalf("Expected email %q, got %q", email, result.User.Email)
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
loginResult, loginErr := authService.Login(username, password)
|
||||
if loginErr == nil {
|
||||
if loginResult.User == nil {
|
||||
t.Fatal("Login result should contain a user")
|
||||
}
|
||||
if loginResult.User.Username != username {
|
||||
t.Fatalf("Expected username %q, got %q", username, loginResult.User.Username)
|
||||
}
|
||||
if loginResult.AccessToken == "" {
|
||||
t.Fatal("Login result should contain an access token")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzIntegrationRepositories(f *testing.F) {
|
||||
helper := NewFuzzTestHelper()
|
||||
helper.RunIntegrationFuzzTest(f, func(t *testing.T, fuzzedData string) {
|
||||
searchQuery := fuzzedData
|
||||
if len(searchQuery) > 100 {
|
||||
searchQuery = searchQuery[:100]
|
||||
}
|
||||
|
||||
sanitizer := repositories.NewSearchSanitizer()
|
||||
sanitizedQuery, err := sanitizer.SanitizeSearchQuery(searchQuery)
|
||||
|
||||
if err == nil {
|
||||
if !utf8.ValidString(sanitizedQuery) {
|
||||
t.Fatal("String contains invalid UTF-8")
|
||||
}
|
||||
|
||||
validationErr := sanitizer.ValidateSearchQuery(sanitizedQuery)
|
||||
_ = validationErr
|
||||
}
|
||||
|
||||
username := fuzzedData
|
||||
email := fuzzedData + "@example.com"
|
||||
|
||||
if len(username) > 50 {
|
||||
username = username[:50]
|
||||
}
|
||||
if len(email) > 100 {
|
||||
email = email[:100]
|
||||
}
|
||||
|
||||
if !utf8.ValidString(username) {
|
||||
t.Fatal("String contains invalid UTF-8")
|
||||
}
|
||||
if !utf8.ValidString(email) {
|
||||
t.Fatal("String contains invalid UTF-8")
|
||||
}
|
||||
|
||||
postTitle := fuzzedData
|
||||
postContent := fuzzedData
|
||||
|
||||
if len(postTitle) > 200 {
|
||||
postTitle = postTitle[:200]
|
||||
}
|
||||
if len(postContent) > 1000 {
|
||||
postContent = postContent[:1000]
|
||||
}
|
||||
|
||||
if !utf8.ValidString(postTitle) {
|
||||
t.Fatal("String contains invalid UTF-8")
|
||||
}
|
||||
if !utf8.ValidString(postContent) {
|
||||
t.Fatal("String contains invalid UTF-8")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func isValidUTF8(s string) bool {
|
||||
for _, r := range s {
|
||||
if r == utf8.RuneError {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func escapeJSON(s string) string {
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
s = strings.ReplaceAll(s, "\n", "\\n")
|
||||
s = strings.ReplaceAll(s, "\r", "\\r")
|
||||
s = strings.ReplaceAll(s, "\t", "\\t")
|
||||
return s
|
||||
}
|
||||
187
internal/fuzz/repositories_fuzz_test.go
Normal file
187
internal/fuzz/repositories_fuzz_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package fuzz
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
func FuzzSearchRepository(f *testing.F) {
|
||||
f.Add("test query")
|
||||
f.Add("")
|
||||
f.Add("SELECT * FROM posts")
|
||||
f.Add(strings.Repeat("a", 1000))
|
||||
f.Add("<script>alert('xss')</script>")
|
||||
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
if len(input) > 1000 {
|
||||
input = input[:1000]
|
||||
}
|
||||
|
||||
if !utf8.ValidString(input) {
|
||||
return
|
||||
}
|
||||
|
||||
db, err := GetFuzzDB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to test database: %v", err)
|
||||
}
|
||||
|
||||
db.Exec("DELETE FROM votes")
|
||||
db.Exec("DELETE FROM posts")
|
||||
db.Exec("DELETE FROM users")
|
||||
db.Exec("DELETE FROM account_deletion_requests")
|
||||
db.Exec("DELETE FROM refresh_tokens")
|
||||
|
||||
postRepo := repositories.NewPostRepository(db)
|
||||
sanitizer := repositories.NewSearchSanitizer()
|
||||
|
||||
t.Run("sanitize_and_search", func(t *testing.T) {
|
||||
sanitized, err := sanitizer.SanitizeSearchQuery(input)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !utf8.ValidString(sanitized) {
|
||||
t.Fatalf("Sanitized query should be valid UTF-8: %q", sanitized)
|
||||
}
|
||||
|
||||
posts, searchErr := postRepo.Search(sanitized, 1, 10)
|
||||
if searchErr != nil {
|
||||
if strings.Contains(searchErr.Error(), "panic") {
|
||||
t.Fatalf("Search should not panic: %v", searchErr)
|
||||
}
|
||||
} else {
|
||||
if posts != nil {
|
||||
_ = len(posts)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate_search_query", func(t *testing.T) {
|
||||
err := sanitizer.ValidateSearchQuery(input)
|
||||
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "panic") {
|
||||
t.Fatalf("ValidateSearchQuery should not panic: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzPostRepository(f *testing.F) {
|
||||
f.Add("test title")
|
||||
f.Add("")
|
||||
f.Add("<script>alert('xss')</script>")
|
||||
f.Add("https://example.com")
|
||||
f.Add(strings.Repeat("a", 500))
|
||||
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
if len(input) > 500 {
|
||||
input = input[:500]
|
||||
}
|
||||
|
||||
if !utf8.ValidString(input) {
|
||||
return
|
||||
}
|
||||
|
||||
db, err := GetFuzzDB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to test database: %v", err)
|
||||
}
|
||||
|
||||
db.Exec("DELETE FROM votes")
|
||||
db.Exec("DELETE FROM posts")
|
||||
db.Exec("DELETE FROM users")
|
||||
db.Exec("DELETE FROM account_deletion_requests")
|
||||
db.Exec("DELETE FROM refresh_tokens")
|
||||
|
||||
postRepo := repositories.NewPostRepository(db)
|
||||
|
||||
var userID uint
|
||||
result := db.Exec(`
|
||||
INSERT INTO users (username, email, password, email_verified, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, datetime('now'), datetime('now'))
|
||||
`, "fuzz_test_user", "fuzz@example.com", "hashedpassword", true)
|
||||
if result.Error != nil {
|
||||
t.Fatalf("Failed to create test user: %v", result.Error)
|
||||
}
|
||||
|
||||
var createdUser struct {
|
||||
ID uint `gorm:"column:id"`
|
||||
}
|
||||
db.Raw("SELECT id FROM users WHERE username = ?", "fuzz_test_user").Scan(&createdUser)
|
||||
userID = createdUser.ID
|
||||
|
||||
t.Run("create_and_get_post", func(t *testing.T) {
|
||||
title := input[:min(len(input), 200)]
|
||||
url := "https://example.com/" + input[:min(len(input), 50)]
|
||||
content := input[:min(len(input), 1000)]
|
||||
|
||||
result := db.Exec(`
|
||||
INSERT INTO posts (title, url, content, author_id, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, datetime('now'), datetime('now'))
|
||||
`, title, url, content, userID)
|
||||
if result.Error != nil {
|
||||
if strings.Contains(result.Error.Error(), "panic") {
|
||||
t.Fatalf("Create should not panic: %v", result.Error)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var postID uint
|
||||
var createdPost struct {
|
||||
ID uint `gorm:"column:id"`
|
||||
}
|
||||
db.Raw("SELECT id FROM posts WHERE author_id = ? ORDER BY id DESC LIMIT 1", userID).Scan(&createdPost)
|
||||
postID = createdPost.ID
|
||||
|
||||
if postID == 0 {
|
||||
t.Fatal("Created post should have an ID")
|
||||
}
|
||||
|
||||
retrieved, getErr := postRepo.GetByID(postID)
|
||||
if getErr != nil {
|
||||
t.Fatalf("GetByID should succeed for created post: %v", getErr)
|
||||
}
|
||||
|
||||
if retrieved == nil {
|
||||
t.Fatal("GetByID should return a post")
|
||||
}
|
||||
|
||||
if retrieved.ID != postID {
|
||||
t.Fatalf("Expected post ID %d, got %d", postID, retrieved.ID)
|
||||
}
|
||||
|
||||
posts, listErr := postRepo.GetAll(10, 0)
|
||||
if listErr != nil {
|
||||
t.Fatalf("GetAll should not error: %v", listErr)
|
||||
}
|
||||
|
||||
if posts == nil {
|
||||
t.Fatal("GetAll should return a slice")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, p := range posts {
|
||||
if p.ID == postID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found && len(posts) > 0 {
|
||||
t.Logf("Created post not found in list (this may be acceptable depending on pagination)")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
Reference in New Issue
Block a user