package services

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"golang.org/x/net/html"
)

var (
	// ErrUnsupportedScheme reports a URL whose scheme is neither http nor https.
	ErrUnsupportedScheme = errors.New("unsupported URL scheme")

	// ErrTitleNotFound reports that no usable page title could be extracted.
	ErrTitleNotFound = errors.New("page title not found")

	// ErrSSRFBlocked reports a URL rejected because its destination is not
	// allowed (for example a loopback or private address).
	ErrSSRFBlocked = errors.New("request blocked for security reasons")

	// ErrTooManyRedirects reports that the redirect limit was exceeded.
	ErrTooManyRedirects = errors.New("too many redirects")
)

const (
	// Response-size limits: at most maxTitleBodyBytes of a body are read, and
	// responses advertising more than maxContentLength bytes are skipped.
	maxTitleBodyBytes = 512 * 1024
	maxContentLength  = 10 * 1024 * 1024

	defaultUserAgent = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"

	// Limits applied to outbound HTTP requests.
	maxRedirects          = 3
	requestTimeout        = 10 * time.Second
	dialTimeout           = 5 * time.Second
	tlsHandshakeTimeout   = 5 * time.Second
	responseHeaderTimeout = 5 * time.Second

	// DNSCache limits: an entry is reused for at most dnsCacheTTL, and at most
	// dnsCacheMaxEntries hostnames are kept at once.
	dnsCacheTTL        = 5 * time.Minute
	dnsCacheMaxEntries = 4096
)

// TitleFetcher fetches the human-readable title of the page at rawURL.
type TitleFetcher interface {
	FetchTitle(ctx context.Context, rawURL string) (string, error)
}

// DNSResolver resolves a hostname to its IP addresses.
type DNSResolver interface {
	LookupIP(hostname string) ([]net.IP, error)
}

// DefaultDNSResolver implements DNSResolver with net.LookupIP.
type DefaultDNSResolver struct{}

// LookupIP resolves hostname with net.LookupIP.
func (d DefaultDNSResolver) LookupIP(hostname string) ([]net.IP, error) {
	return net.LookupIP(hostname)
}

// dnsCacheEntry is one cached resolution result and the instant it goes stale.
type dnsCacheEntry struct {
	ips     []net.IP
	expires time.Time
}

// DNSCache is a mutex-guarded hostname -> IP list cache.
//
// Entries expire after a TTL and the number of hostnames is bounded. Hostnames
// can come from untrusted input, so an unbounded, never-expiring map would
// grow without limit. It would also keep serving a host's old addresses forever.
type DNSCache struct {
	mu         sync.RWMutex
	data       map[string]dnsCacheEntry
	ttl        time.Duration
	maxEntries int
}

// NewDNSCache returns an empty cache that keeps entries for dnsCacheTTL and
// holds at most dnsCacheMaxEntries hostnames.
func NewDNSCache() *DNSCache {
	return newDNSCacheWithLimits(dnsCacheTTL, dnsCacheMaxEntries)
}

// newDNSCacheWithLimits returns an empty cache whose entries live for ttl and
// which holds at most maxEntries hostnames. A ttl <= 0 disables caching (Set
// becomes a no-op); maxEntries <= 0 removes the size bound.
func newDNSCacheWithLimits(ttl time.Duration, maxEntries int) *DNSCache {
	return &DNSCache{
		data:       make(map[string]dnsCacheEntry),
		ttl:        ttl,
		maxEntries: maxEntries,
	}
}

// Get returns the cached addresses for hostname and whether an unexpired
// entry exists. The returned slice is shared with the cache; do not modify it.
func (c *DNSCache) Get(hostname string) ([]net.IP, bool) {
	c.mu.RLock()
	entry, ok := c.data[hostname]
	c.mu.RUnlock()
	if !ok || !time.Now().Before(entry.expires) {
		return nil, false
	}
	return entry.ips, true
}

// Set stores ips for hostname, replacing any previous entry and restarting its
// TTL. When inserting a new hostname into a full cache, expired entries are
// purged first; if that is not enough, arbitrary entries are evicted.
func (c *DNSCache) Set(hostname string, ips []net.IP) {
	if c.ttl <= 0 {
		return
	}
	now := time.Now()
	c.mu.Lock()
	defer c.mu.Unlock()
	if _, exists := c.data[hostname]; !exists && c.maxEntries > 0 && len(c.data) >= c.maxEntries {
		c.evictLocked(now)
	}
	c.data[hostname] = dnsCacheEntry{ips: ips, expires: now.Add(c.ttl)}
}

// evictLocked makes room for at least one new entry. The caller must hold
// c.mu for writing.
func (c *DNSCache) evictLocked(now time.Time) {
	for host, entry := range c.data {
		if !now.Before(entry.expires) {
			delete(c.data, host)
		}
	}
	// Map iteration order is unspecified, so the entries dropped here are
	// arbitrary; this only runs when every remaining entry is still fresh.
	for host := range c.data {
		if len(c.data) < c.maxEntries {
			return
		}
		delete(c.data, host)
	}
}

// CachedDNSResolver wraps a DNSResolver and memoizes successful lookups in a
// DNSCache. Failed lookups are not cached.
type CachedDNSResolver struct {
	resolver DNSResolver
	cache    *DNSCache
}

// NewCachedDNSResolver returns a CachedDNSResolver over resolver, backed by a
// fresh DNSCache.
func NewCachedDNSResolver(resolver DNSResolver) *CachedDNSResolver {
	return &CachedDNSResolver{
		resolver: resolver,
		cache:    NewDNSCache(),
	}
}

// LookupIP returns the cached addresses for hostname when an unexpired entry
// exists; otherwise it resolves hostname with the wrapped resolver and caches
// a successful result.
func (c *CachedDNSResolver) LookupIP(hostname string) ([]net.IP, error) {
	if ips, ok := c.cache.Get(hostname); ok {
		return ips, nil
	}
	ips, err := c.resolver.LookupIP(hostname)
	if err != nil {
		return nil, err
	}
	c.cache.Set(hostname, ips)
	return ips, nil
}

// CustomDialer holds the DNSCache and the underlying net.Dialer used by its
// DialContext method.
type CustomDialer struct {
	cache    *DNSCache
	fallback *net.Dialer
}
func NewCustomDialer(cache *DNSCache) *CustomDialer { return &CustomDialer{ cache: cache, fallback: &net.Dialer{ Timeout: dialTimeout, }, } } func (d *CustomDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { host, port, err := net.SplitHostPort(address) if err != nil { return d.fallback.DialContext(ctx, network, address) } if ips, exists := d.cache.Get(host); exists { for _, ip := range ips { ipAddr := net.JoinHostPort(ip.String(), port) if conn, err := d.fallback.DialContext(ctx, network, ipAddr); err == nil { return conn, nil } } } return d.fallback.DialContext(ctx, network, address) } type URLMetadataService struct { client *http.Client resolver DNSResolver dnsCache *DNSCache approvedHosts map[string]bool mu sync.RWMutex } func NewURLMetadataService() *URLMetadataService { dnsCache := NewDNSCache() cachedResolver := NewCachedDNSResolver(DefaultDNSResolver{}) customDialer := NewCustomDialer(dnsCache) svc := &URLMetadataService{ resolver: cachedResolver, dnsCache: dnsCache, approvedHosts: make(map[string]bool), } transport := &http.Transport{ DialContext: customDialer.DialContext, MaxIdleConns: 100, MaxIdleConnsPerHost: 10, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: tlsHandshakeTimeout, ResponseHeaderTimeout: responseHeaderTimeout, DisableKeepAlives: false, } svc.client = &http.Client{ Timeout: requestTimeout, Transport: transport, CheckRedirect: func(req *http.Request, via []*http.Request) error { if len(via) >= maxRedirects { return ErrTooManyRedirects } hostname := req.URL.Hostname() svc.mu.RLock() approved := svc.approvedHosts[hostname] svc.mu.RUnlock() if approved { return nil } if err := svc.validateURLForSSRF(req.URL); err != nil { return err } svc.mu.Lock() svc.approvedHosts[hostname] = true svc.mu.Unlock() return nil }, } return svc } func (s *URLMetadataService) FetchTitle(ctx context.Context, rawURL string) (string, error) { if rawURL == "" { return "", errors.New("empty URL") } parsed, err := 
url.Parse(rawURL) if err != nil { return "", fmt.Errorf("parse url: %w", err) } if parsed.Scheme != "http" && parsed.Scheme != "https" { return "", ErrUnsupportedScheme } hostname := parsed.Hostname() s.mu.RLock() approved := s.approvedHosts[hostname] s.mu.RUnlock() if !approved { if err := s.validateURLForSSRF(parsed); err != nil { return "", err } s.mu.Lock() s.approvedHosts[hostname] = true s.mu.Unlock() } request, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) if err != nil { return "", fmt.Errorf("build request: %w", err) } request.Header.Set("User-Agent", defaultUserAgent) request.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8") request.Header.Set("Accept-Language", "en-US,en;q=0.5") resp, err := s.client.Do(request) if err != nil { return "", fmt.Errorf("fetch url: %w", err) } defer resp.Body.Close() if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) } contentType := resp.Header.Get("Content-Type") if !strings.Contains(strings.ToLower(contentType), "text/html") { return "", ErrTitleNotFound } contentLength := resp.ContentLength if contentLength > maxContentLength { return "", ErrTitleNotFound } limited := io.LimitReader(resp.Body, maxTitleBodyBytes) body, err := io.ReadAll(limited) if err != nil { return "", fmt.Errorf("read body: %w", err) } title := s.ExtractTitleFromHTML(string(body)) if title != "" { return title, nil } return "", ErrTitleNotFound } func (s *URLMetadataService) ExtractTitleFromHTML(html string) string { if title := s.ExtractFromTitleTag(html); title != "" { return title } if title := s.ExtractFromOpenGraph(html); title != "" { return title } if title := s.ExtractFromJSONLD(html); title != "" { return title } if title := s.ExtractFromTwitterCard(html); title != "" { return title } if title := s.extractFromMetaTags(html); title != "" { return title } return "" } func (s 
*URLMetadataService) ExtractFromTitleTag(htmlContent string) string { tokenizer := html.NewTokenizer(strings.NewReader(htmlContent)) for { tokenType := tokenizer.Next() switch tokenType { case html.ErrorToken: if errors.Is(tokenizer.Err(), io.EOF) { return "" } return "" case html.StartTagToken, html.SelfClosingTagToken: token := tokenizer.Token() if strings.EqualFold(token.Data, "title") { textTokenType := tokenizer.Next() if textTokenType == html.TextToken { rawTitle := tokenizer.Token().Data cleaned := s.optimizedTitleClean(rawTitle) if cleaned != "" { return cleaned } } } } } } func (s *URLMetadataService) ExtractFromOpenGraph(htmlContent string) string { lines := strings.Split(htmlContent, "\n") for _, line := range lines { line = strings.TrimSpace(line) if strings.Contains(strings.ToLower(line), `property="og:title"`) && strings.Contains(line, `content="`) { start := strings.Index(line, `content="`) if start != -1 { start += 9 end := strings.Index(line[start:], `"`) if end != -1 { title := line[start : start+end] cleaned := s.optimizedTitleClean(title) if cleaned != "" { return cleaned } } } } } return "" } func (s *URLMetadataService) ExtractFromJSONLD(htmlContent string) string { lines := strings.Split(htmlContent, "\n") for _, line := range lines { line = strings.TrimSpace(line) if strings.Contains(line, `"@type":"VideoObject"`) || strings.Contains(line, `"@type":"WebPage"`) { if strings.Contains(line, `"name":`) { start := strings.Index(line, `"name":`) if start != -1 { start += 7 for i := start; i < len(line); i++ { if line[i] == '"' { start = i + 1 break } } end := strings.Index(line[start:], `"`) if end != -1 { title := line[start : start+end] cleaned := s.optimizedTitleClean(title) if cleaned != "" { return cleaned } } } } } } return "" } func (s *URLMetadataService) ExtractFromTwitterCard(htmlContent string) string { lines := strings.Split(htmlContent, "\n") for _, line := range lines { line = strings.TrimSpace(line) if 
strings.Contains(strings.ToLower(line), `name="twitter:title"`) && strings.Contains(line, `content="`) { start := strings.Index(line, `content="`) if start != -1 { start += 9 end := strings.Index(line[start:], `"`) if end != -1 { title := line[start : start+end] cleaned := s.optimizedTitleClean(title) if cleaned != "" { return cleaned } } } } } return "" } func (s *URLMetadataService) extractFromMetaTags(htmlContent string) string { lines := strings.Split(htmlContent, "\n") for _, line := range lines { line = strings.TrimSpace(line) if strings.Contains(strings.ToLower(line), `name="title"`) && strings.Contains(line, `content="`) { start := strings.Index(line, `content="`) if start != -1 { start += 9 end := strings.Index(line[start:], `"`) if end != -1 { title := line[start : start+end] cleaned := s.optimizedTitleClean(title) if cleaned != "" { return cleaned } } } } } return "" } func (s *URLMetadataService) optimizedTitleClean(title string) string { if title == "" { return "" } var result strings.Builder result.Grow(len(title)) inWhitespace := false started := false for _, r := range title { if r == ' ' || r == '\t' || r == '\n' || r == '\r' { if started && !inWhitespace { result.WriteRune(' ') inWhitespace = true } } else { result.WriteRune(r) inWhitespace = false started = true } } cleaned := result.String() if len(cleaned) > 0 && cleaned[len(cleaned)-1] == ' ' { cleaned = cleaned[:len(cleaned)-1] } return cleaned } func (s *URLMetadataService) validateURLForSSRF(u *url.URL) error { switch { case u == nil, u.Scheme != "http" && u.Scheme != "https", u.Host == "", u.Hostname() == "", isLocalhost(u.Hostname()): return ErrSSRFBlocked } ips, err := s.resolver.LookupIP(u.Hostname()) if err != nil { return ErrSSRFBlocked } for _, ip := range ips { if isPrivateOrReservedIP(ip) { return ErrSSRFBlocked } } return nil } func isLocalhost(hostname string) bool { hostname = strings.ToLower(hostname) localhostNames := []string{ "localhost", "127.0.0.1", "::1", "0.0.0.0", 
"0:0:0:0:0:0:0:1", "0:0:0:0:0:0:0:0", } for _, name := range localhostNames { if hostname == name { return true } } return false } func isPrivateOrReservedIP(ip net.IP) bool { if ip == nil { return true } ipv4 := ip.To4() if ipv4 == nil { return isPrivateIPv6(ip) } privateRanges := []struct { start, end net.IP }{ {net.IPv4(10, 0, 0, 0), net.IPv4(10, 255, 255, 255)}, {net.IPv4(172, 16, 0, 0), net.IPv4(172, 31, 255, 255)}, {net.IPv4(192, 168, 0, 0), net.IPv4(192, 168, 255, 255)}, {net.IPv4(127, 0, 0, 0), net.IPv4(127, 255, 255, 255)}, {net.IPv4(169, 254, 0, 0), net.IPv4(169, 254, 255, 255)}, {net.IPv4(224, 0, 0, 0), net.IPv4(239, 255, 255, 255)}, {net.IPv4(240, 0, 0, 0), net.IPv4(255, 255, 255, 255)}, } for _, r := range privateRanges { if ipInRange(ipv4, r.start, r.end) { return true } } return false } func isPrivateIPv6(ip net.IP) bool { privateRanges := []struct { prefix []byte length int }{ {[]byte{0xfc, 0x00}, 7}, {[]byte{0xfe, 0x80}, 10}, {[]byte{0xff, 0x00}, 8}, {[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, 128}, } for _, r := range privateRanges { if ipv6InRange(ip, r.prefix, r.length) { return true } } return false } func ipInRange(ip, start, end net.IP) bool { ipInt := ipToInt(ip) startInt := ipToInt(start) endInt := ipToInt(end) return ipInt >= startInt && ipInt <= endInt } func ipToInt(ip net.IP) uint32 { ipv4 := ip.To4() if ipv4 == nil { return 0 } return uint32(ipv4[0])<<24 + uint32(ipv4[1])<<16 + uint32(ipv4[2])<<8 + uint32(ipv4[3]) } func ipv6InRange(ip net.IP, prefix []byte, length int) bool { ipBytes := ip.To16() if ipBytes == nil { return false } bytesToCompare := length / 8 bitsToCompare := length % 8 for i := 0; i < bytesToCompare && i < len(prefix) && i < len(ipBytes); i++ { if ipBytes[i] != prefix[i] { return false } } if bitsToCompare > 0 && bytesToCompare < len(prefix) && bytesToCompare < len(ipBytes) { mask := byte(0xff) << (8 - bitsToCompare) if (ipBytes[bytesToCompare] & mask) != 
(prefix[bytesToCompare] & mask) { return false } } return true }