refactor: switch from net.IP to netip.Addr

This change switches net.IP to netip.Addr added in Go 1.18. This results in slightly better performance and memory utilization for very large threat feeds (over 500,000 entries).
This commit is contained in:
Ryan Smith
2025-05-08 16:26:33 -07:00
parent 375da6eeac
commit f9d7b767bc
4 changed files with 37 additions and 37 deletions

View File

@@ -5,7 +5,7 @@ import (
"encoding/csv"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"strings"
@@ -65,11 +65,8 @@ var (
func Update(ip string) {
// Check if the given IP string is a private address. The threat feed may
// be configured to include or exclude private IPs.
netIP := net.ParseIP(ip)
if netIP == nil || netIP.IsLoopback() {
return
}
if !cfg.ThreatFeed.IsPrivateIncluded && netIP.IsPrivate() {
parsedIP, err := netip.ParseAddr(ip)
if err != nil || parsedIP.IsLoopback() || (!cfg.ThreatFeed.IsPrivateIncluded && parsedIP.IsPrivate()) {
return
}

View File

@@ -2,10 +2,9 @@ package threatfeed
import (
"bufio"
"bytes"
"cmp"
"fmt"
"net"
"net/netip"
"os"
"slices"
"strings"
@@ -16,11 +15,11 @@ import (
// feedEntry represents an individual entry in the threat feed.
type feedEntry struct {
IP string `json:"ip"`
IPBytes net.IP `json:"-"`
Added time.Time `json:"added"`
LastSeen time.Time `json:"last_seen"`
Observations int `json:"observations"`
IP string `json:"ip"`
IPBytes netip.Addr `json:"-"`
Added time.Time `json:"added"`
LastSeen time.Time `json:"last_seen"`
Observations int `json:"observations"`
}
// feedEntries is a slice of feedEntry structs. It represents the threat feed
@@ -86,13 +85,13 @@ loop:
continue
}
parsedIP := net.ParseIP(ip)
if parsedIP == nil || (parsedIP.IsPrivate() && !cfg.ThreatFeed.IsPrivateIncluded) {
parsedIP, err := netip.ParseAddr(ip)
if err != nil || (parsedIP.IsPrivate() && !cfg.ThreatFeed.IsPrivateIncluded) {
continue
}
for _, ipnet := range excludedCIDR {
if ipnet.Contains(parsedIP) {
for _, prefix := range excludedCIDR {
if prefix.Contains(parsedIP) {
continue loop
}
}
@@ -120,9 +119,9 @@ loop:
// should contain an IP address or CIDR. It returns a map of the unique IPs and
// a slice of the CIDR ranges found in the file. The file may include comments
// using "#". The "#" symbol on a line and everything after is ignored.
func parseExcludeList(filepath string) (map[string]struct{}, []*net.IPNet, error) {
func parseExcludeList(filepath string) (map[string]struct{}, []netip.Prefix, error) {
if len(filepath) == 0 {
return map[string]struct{}{}, []*net.IPNet{}, nil
return nil, nil, nil
}
f, err := os.Open(filepath)
@@ -134,22 +133,24 @@ func parseExcludeList(filepath string) (map[string]struct{}, []*net.IPNet, error
// `ips` stores individual IPs to exclude, and `cidr` stores CIDR networks
// to exclude.
ips := make(map[string]struct{})
cidr := []*net.IPNet{}
cidr := []netip.Prefix{}
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
line := scanner.Text()
// Remove comments from text.
if i := strings.Index(line, "#"); i != -1 {
line = strings.TrimSpace(line[:i])
// Remove comments and trim.
if i := strings.IndexByte(line, '#'); i != -1 {
line = line[:i]
}
line = strings.TrimSpace(line)
if len(line) == 0 {
continue
}
if len(line) > 0 {
if _, ipnet, err := net.ParseCIDR(line); err == nil {
cidr = append(cidr, ipnet)
} else {
ips[line] = struct{}{}
}
if prefix, err := netip.ParsePrefix(line); err == nil {
cidr = append(cidr, prefix)
} else {
ips[line] = struct{}{}
}
}
if err := scanner.Err(); err != nil {
@@ -164,13 +165,13 @@ func (f feedEntries) applySort(method sortMethod, direction sortDirection) {
switch method {
case byIP:
slices.SortFunc(f, func(a, b feedEntry) int {
return bytes.Compare(a.IPBytes, b.IPBytes)
return a.IPBytes.Compare(b.IPBytes)
})
case byLastSeen:
slices.SortFunc(f, func(a, b feedEntry) int {
t := a.LastSeen.Compare(b.LastSeen)
if t == 0 {
return bytes.Compare(a.IPBytes, b.IPBytes)
return a.IPBytes.Compare(b.IPBytes)
}
return t
})
@@ -178,7 +179,7 @@ func (f feedEntries) applySort(method sortMethod, direction sortDirection) {
slices.SortFunc(f, func(a, b feedEntry) int {
t := a.Added.Compare(b.Added)
if t == 0 {
return bytes.Compare(a.IPBytes, b.IPBytes)
return a.IPBytes.Compare(b.IPBytes)
}
return t
})
@@ -186,7 +187,7 @@ func (f feedEntries) applySort(method sortMethod, direction sortDirection) {
slices.SortFunc(f, func(a, b feedEntry) int {
t := cmp.Compare(a.Observations, b.Observations)
if t == 0 {
return bytes.Compare(a.IPBytes, b.IPBytes)
return a.IPBytes.Compare(b.IPBytes)
}
return t
})

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net"
"net/http"
"net/netip"
"sync"
"golang.org/x/net/websocket"
@@ -65,7 +66,7 @@ func handleWebSocket(ws *websocket.Conn) {
if err != nil {
return
}
if netIP := net.ParseIP(ip); !netIP.IsPrivate() && !netIP.IsLoopback() {
if parsedIP, err := netip.ParseAddr(ip); err != nil || (!parsedIP.IsPrivate() && !parsedIP.IsLoopback()) {
return
}
fmt.Println("[Threat Feed]", ip, "established WebSocket connection")

View File

@@ -3,6 +3,7 @@ package threatfeed
import (
"net"
"net/http"
"net/netip"
)
// enforcePrivateIP is a middleware that restricts access to the HTTP server
@@ -12,11 +13,11 @@ func enforcePrivateIP(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(w, "Could not get IP", http.StatusInternalServerError)
http.Error(w, "", http.StatusForbidden)
return
}
if netIP := net.ParseIP(ip); !netIP.IsPrivate() && !netIP.IsLoopback() {
if parsedIP, err := netip.ParseAddr(ip); err != nil || (!parsedIP.IsPrivate() && !parsedIP.IsLoopback()) {
http.Error(w, "", http.StatusForbidden)
return
}