Refactor feed operations

- Add parseParams function to parse HTTP query parameters.
- Switch from functional options to an options struct for prepareFeed.
- Define a `feed` type for []net.IP.
- Add `applySort` method for sorting the feed.
- Change `convertToObservables` and `convertToIndicators` from functions to methods.
- Optimize exclude list processing.
  - parseExcludeList now determines and tracks whether each entry is a single IP or a CIDR block and now returns both a map of single IPs and a slice of CIDR blocks.
  - prepareFeed now calls parseExcludeList prior to iterating over iocData.
  - Delete the filterIPs function. All filtering logic is applied while iterating over iocData rather than afterwards.
This commit is contained in:
Ryan Smith
2024-11-20 08:31:20 -08:00
parent 716988a546
commit b7ed763661
2 changed files with 147 additions and 145 deletions

View File

@@ -3,6 +3,7 @@ package threatfeed
import (
"bufio"
"bytes"
"cmp"
"fmt"
"net"
"os"
@@ -13,182 +14,155 @@ import (
"github.com/r-smith/deceptifeed/internal/stix"
)
type feed []net.IP
// sortMethod is a type representing threat feed sorting methods.
type sortMethod int
const (
byIP sortMethod = iota
byLastSeen
byAdded
byThreatScore
)
// feedOptions define configurable options for serving the threat feed.
type feedOptions struct {
sortMethod sortMethod
seenAfter time.Time
}
// option defines a function type for configuring `feedOptions`.
type option func(*feedOptions)
// sortByLastSeen returns an option that sets the sort method in `feedOptions`
// to sort the threat feed by the last seen time.
func sortByLastSeen() option {
return func(o *feedOptions) {
o.sortMethod = byLastSeen
}
}
// seenAfter returns an option that sets the the `seenAfter` time in
// `feedOptions`. This filters the feed to include only entries seen after the
// specified timestamp.
func seenAfter(after time.Time) option {
return func(o *feedOptions) {
o.seenAfter = after
}
limit int
page int
}
// prepareFeed filters, processes, and sorts IP addresses from the threat feed.
// The resulting slice of `net.IP` represents the current threat feed to be
// served to clients.
func prepareFeed(options ...option) []net.IP {
func prepareFeed(options ...feedOptions) feed {
// Set default feed options.
opt := feedOptions{
sortMethod: byIP,
seenAfter: time.Time{},
}
for _, o := range options {
o(&opt)
// Override default options if provided.
if len(options) > 0 {
opt = options[0]
}
// Parse IPs from iocData to net.IP. Skip IPs that are expired, below the
// minimum threat score, or are private, based on the configuration.
excludedIPs, excludedCIDR, err := parseExcludeList(configuration.ExcludeListPath)
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to read threat feed exclude list:", err)
}
// Parse and filter IPs from iocData into the threat feed.
mutex.Lock()
netIPs := make([]net.IP, 0, len(iocData))
threats := make(feed, 0, len(iocData))
loop:
for ip, ioc := range iocData {
if ioc.expired() || ioc.ThreatScore < configuration.MinimumThreatScore || !ioc.LastSeen.After(opt.seenAfter) {
continue
}
ipParsed := net.ParseIP(ip)
if ipParsed == nil {
parsedIP := net.ParseIP(ip)
if parsedIP == nil || (parsedIP.IsPrivate() && !configuration.IsPrivateIncluded) {
continue
}
if !configuration.IsPrivateIncluded && ipParsed.IsPrivate() {
for _, ipnet := range excludedCIDR {
if ipnet.Contains(parsedIP) {
continue loop
}
}
if _, found := excludedIPs[ip]; found {
continue
}
netIPs = append(netIPs, ipParsed)
threats = append(threats, parsedIP)
}
mutex.Unlock()
// If an exclude list is provided, filter the IP list.
if len(configuration.ExcludeListPath) > 0 {
ipsToRemove, err := parseExcludeList(configuration.ExcludeListPath)
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to read threat feed exclude list:", err)
} else {
netIPs = filterIPs(netIPs, ipsToRemove)
}
}
threats.applySort(opt.sortMethod)
// Apply sorting.
switch opt.sortMethod {
case byIP:
slices.SortFunc(netIPs, func(a, b net.IP) int {
return bytes.Compare(a, b)
})
case byLastSeen:
mutex.Lock()
slices.SortFunc(netIPs, func(a, b net.IP) int {
// Sort by LastSeen date, and if equal, sort by IP.
dateCompare := iocData[a.String()].LastSeen.Compare(iocData[b.String()].LastSeen)
if dateCompare != 0 {
return dateCompare
}
return bytes.Compare(a, b)
})
mutex.Unlock()
}
return netIPs
return threats
}
// parseExcludeList reads IP addresses and CIDR ranges from a file. Each line
// should contain an IP address or CIDR. It returns a map of the unique IPs and
// CIDR ranges found in the file.
func parseExcludeList(filepath string) (map[string]struct{}, error) {
ips := make(map[string]struct{})
// a slice of the CIDR ranges found in the file.
func parseExcludeList(filepath string) (map[string]struct{}, []*net.IPNet, error) {
if len(filepath) == 0 {
return map[string]struct{}{}, []*net.IPNet{}, nil
}
file, err := os.Open(filepath)
if err != nil {
return nil, err
return nil, nil, err
}
defer file.Close()
// `ips` stores individual IPs to exclude, and `cidr` stores CIDR networks
// to exclude.
ips := make(map[string]struct{})
cidr := []*net.IPNet{}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if len(line) > 0 {
ips[line] = struct{}{}
if _, ipnet, err := net.ParseCIDR(line); err == nil {
cidr = append(cidr, ipnet)
} else {
ips[line] = struct{}{}
}
}
}
if err := scanner.Err(); err != nil {
return nil, err
return nil, nil, err
}
return ips, nil
return ips, cidr, nil
}
// filterIPs removes IPs from ipList that are found in the ipsToRemove map. The
// keys in ipsToRemove may be single IP addresses or CIDR ranges. If a key is a
// CIDR range, an IP will be removed if it falls within that range.
func filterIPs(ipList []net.IP, ipsToRemove map[string]struct{}) []net.IP {
if len(ipsToRemove) == 0 {
return ipList
func (f feed) applySort(method sortMethod) {
switch method {
case byIP:
slices.SortFunc(f, func(a, b net.IP) int {
return bytes.Compare(a, b)
})
case byLastSeen:
mutex.Lock()
slices.SortFunc(f, func(a, b net.IP) int {
return iocData[a.String()].LastSeen.Compare(iocData[b.String()].LastSeen)
})
mutex.Unlock()
case byAdded:
mutex.Lock()
slices.SortFunc(f, func(a, b net.IP) int {
return iocData[a.String()].Added.Compare(iocData[b.String()].Added)
})
mutex.Unlock()
case byThreatScore:
mutex.Lock()
slices.SortFunc(f, func(a, b net.IP) int {
return cmp.Compare(iocData[a.String()].ThreatScore, iocData[b.String()].ThreatScore)
})
mutex.Unlock()
}
cidrNetworks := []*net.IPNet{}
for cidr := range ipsToRemove {
if _, ipnet, err := net.ParseCIDR(cidr); err == nil {
cidrNetworks = append(cidrNetworks, ipnet)
}
}
i := 0
for _, ip := range ipList {
if _, found := ipsToRemove[ip.String()]; found {
continue
}
contains := false
for _, ipnet := range cidrNetworks {
if ipnet.Contains(ip) {
contains = true
break
}
}
if !contains {
ipList[i] = ip
i++
}
}
return ipList[:i]
}
// convertToIndicators converts IP addresses from the threat feed into a
// collection of STIX Indicator objects.
func convertToIndicators(ips []net.IP) []stix.Object {
if len(ips) == 0 {
func (f feed) convertToIndicators() []stix.Object {
if len(f) == 0 {
return []stix.Object{}
}
const indicator = "indicator"
result := make([]stix.Object, 0, len(ips)+1)
result := make([]stix.Object, 0, len(f)+1)
// Add the Deceptifeed `Identity` as the first object in the collection.
// All IP addresses in the collection will reference this identity as
// the creator.
result = append(result, stix.DeceptifeedIdentity())
for _, ip := range ips {
for _, ip := range f {
if ioc, found := iocData[ip.String()]; found {
pattern := "[ipv4-addr:value = '"
if strings.Contains(ip.String(), ":") {
@@ -230,19 +204,19 @@ func convertToIndicators(ips []net.IP) []stix.Object {
// convertToObservables converts IP addresses from the threat feed into a
// collection of STIX Cyber-observable Objects.
func convertToObservables(ips []net.IP) []stix.Object {
if len(ips) == 0 {
func (f feed) convertToObservables() []stix.Object {
if len(f) == 0 {
return []stix.Object{}
}
result := make([]stix.Object, 0, len(ips)+1)
result := make([]stix.Object, 0, len(f)+1)
// Add the Deceptifeed `Identity` as the first object in the collection.
// All IP addresses in the collection will reference this identity as
// the creator.
result = append(result, stix.DeceptifeedIdentity())
for _, ip := range ips {
for _, ip := range f {
if _, found := iocData[ip.String()]; found {
t := "ipv4-addr"
if strings.Contains(ip.String(), ":") {

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/r-smith/deceptifeed/internal/stix"
@@ -154,7 +155,7 @@ func handleSTIX2(w http.ResponseWriter, r *http.Request) {
result := stix.Bundle{
Type: bundle,
ID: stix.NewID(bundle),
Objects: convertToIndicators(prepareFeed()),
Objects: prepareFeed().convertToIndicators(),
}
w.Header().Set("Content-Type", stix.ContentType)
@@ -172,7 +173,7 @@ func handleSTIX2Simple(w http.ResponseWriter, r *http.Request) {
result := stix.Bundle{
Type: bundle,
ID: stix.NewID(bundle),
Objects: convertToObservables(prepareFeed()),
Objects: prepareFeed().convertToObservables(),
}
w.Header().Set("Content-Type", stix.ContentType)
@@ -265,47 +266,24 @@ func handleTAXIICollections(w http.ResponseWriter, r *http.Request) {
// structured according to the requested TAXII collection and wrapped in a
// TAXII Envelope. Request URL format: `{api-root}/collections/{id}/objects/`.
func handleTAXIIObjects(w http.ResponseWriter, r *http.Request) {
// Set default values.
after := time.Time{}
limit := 0
page := 0
var err error
// Parse the URL query parameters.
if len(r.URL.Query().Get("added_after")) > 0 {
after, err = time.Parse(time.RFC3339, r.URL.Query().Get("added_after"))
if err != nil {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
}
if len(r.URL.Query().Get("limit")) > 0 {
limit, err = strconv.Atoi(r.URL.Query().Get("limit"))
if err != nil {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
}
if len(r.URL.Query().Get("next")) > 0 {
page, err = strconv.Atoi(r.URL.Query().Get("next"))
if err != nil {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
opt, err := parseParams(r)
if err != nil {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
// Ensure a minimum page number of 1.
if page < 1 {
page = 1
if opt.page < 1 {
opt.page = 1
}
// Build the requested collection.
result := taxii.Envelope{}
switch r.PathValue("id") {
case taxii.IndicatorsID, taxii.IndicatorsAlias:
result.Objects = convertToIndicators(prepareFeed(sortByLastSeen(), seenAfter(after)))
result.Objects = prepareFeed(opt).convertToIndicators()
case taxii.ObservablesID, taxii.ObservablesAlias:
result.Objects = convertToObservables(prepareFeed(sortByLastSeen(), seenAfter(after)))
result.Objects = prepareFeed(opt).convertToObservables()
default:
handleTAXIINotFound(w, r)
return
@@ -313,13 +291,13 @@ func handleTAXIIObjects(w http.ResponseWriter, r *http.Request) {
// Paginate. result.Objects may be resliced depending on the requested
// limit and page number.
result.Objects, result.More = paginate(result.Objects, limit, page)
result.Objects, result.More = paginate(result.Objects, opt.limit, opt.page)
// If more results are available, include the `next` property in the
// response with the next page number.
if result.More {
if page+1 > 0 {
result.Next = strconv.Itoa(page + 1)
if opt.page+1 > 0 {
result.Next = strconv.Itoa(opt.page + 1)
}
}
@@ -393,6 +371,56 @@ func paginate(items []stix.Object, limit int, page int) ([]stix.Object, bool) {
return items[start:end], more
}
// parseParams extracts HTTP query parameters and maps them to options for
// controlling the threat feed output.
func parseParams(r *http.Request) (feedOptions, error) {
opt := feedOptions{}
// Handle TAXII parameters.
if strings.HasPrefix(r.URL.Path, taxii.APIRoot) {
// TAXII requires results to be sorted by object creation date.
// However, since IPs in the threat feed may have their `LastSeen` date
// updated after being added, it makes more sense to sort by the last
// seen date instead. Otherwise, clients may miss updates if they are
// only looking for newly added results.
opt.sortMethod = byLastSeen
var err error
if len(r.URL.Query().Get("added_after")) > 0 {
opt.seenAfter, err = time.Parse(time.RFC3339, r.URL.Query().Get("added_after"))
if err != nil {
return feedOptions{}, err
}
}
if len(r.URL.Query().Get("limit")) > 0 {
opt.limit, err = strconv.Atoi(r.URL.Query().Get("limit"))
if err != nil {
return feedOptions{}, err
}
}
if len(r.URL.Query().Get("next")) > 0 {
opt.page, err = strconv.Atoi(r.URL.Query().Get("next"))
if err != nil {
return feedOptions{}, err
}
}
return opt, nil
}
switch r.URL.Query().Get("sort") {
case "last_seen":
opt.sortMethod = byLastSeen
case "added":
opt.sortMethod = byAdded
case "threat_score":
opt.sortMethod = byThreatScore
default:
opt.sortMethod = byIP
}
return opt, nil
}
// handleEmpty handles HTTP requests to /empty. It returns an empty body with
// status code 200. This endpoint is useful for temporarily clearing the threat
// feed data in firewalls.