Allow Flow Routines to be cancellable (#40)

* Allow Flow Routines to be cancellable
This commit is contained in:
Mario Macias
2021-11-01 00:42:07 +01:00
committed by GitHub
parent 92043a6233
commit d1e1ace318
7 changed files with 259 additions and 37 deletions

View File

@@ -51,6 +51,8 @@ func (s *TemplateSystem) GetTemplate(version uint16, obsDomainId uint32, templat
}
type StateNetFlow struct {
stopper
Format format.FormatInterface
Transport transport.TransportInterface
Logger Logger
@@ -373,7 +375,10 @@ func (s *StateNetFlow) initConfig() {
}
func (s *StateNetFlow) FlowRoutine(workers int, addr string, port int, reuseport bool) error {
if err := s.start(); err != nil {
return err
}
s.InitTemplates()
s.initConfig()
return UDPRoutine("NetFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
return UDPStoppableRoutine(s.stopCh, "NetFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
}

View File

@@ -13,6 +13,8 @@ import (
)
type StateNFLegacy struct {
stopper
Format format.FormatInterface
Transport transport.TransportInterface
Logger Logger
@@ -95,5 +97,8 @@ func (s *StateNFLegacy) DecodeFlow(msg interface{}) error {
}
func (s *StateNFLegacy) FlowRoutine(workers int, addr string, port int, reuseport bool) error {
return UDPRoutine("NetFlowV5", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
if err := s.start(); err != nil {
return err
}
return UDPStoppableRoutine(s.stopCh, "NetFlowV5", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
}

View File

@@ -14,6 +14,8 @@ import (
)
type StateSFlow struct {
stopper
Format format.FormatInterface
Transport transport.TransportInterface
Logger Logger
@@ -153,6 +155,9 @@ func (s *StateSFlow) initConfig() {
}
func (s *StateSFlow) FlowRoutine(workers int, addr string, port int, reuseport bool) error {
if err := s.start(); err != nil {
return err
}
s.initConfig()
return UDPRoutine("sFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
return UDPStoppableRoutine(s.stopCh, "sFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
}

28
utils/stopper.go Normal file
View File

@@ -0,0 +1,28 @@
package utils
import (
"errors"
)
// ErrAlreadyStarted error happens when you try to start twice a flow routine
var ErrAlreadyStarted = errors.New("the routine is already started")
// stopper mechanism, common for all the flow routines
type stopper struct {
stopCh chan struct{}
}
func (s *stopper) start() error {
if s.stopCh != nil {
return ErrAlreadyStarted
}
s.stopCh = make(chan struct{})
return nil
}
func (s *stopper) Shutdown() {
if s.stopCh != nil {
close(s.stopCh)
s.stopCh = nil
}
}

51
utils/stopper_test.go Normal file
View File

@@ -0,0 +1,51 @@
package utils
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStopper(t *testing.T) {
r := routine{}
require.False(t, r.Running)
require.NoError(t, r.StartRoutine())
assert.True(t, r.Running)
r.Shutdown()
assert.Eventually(t, func() bool {
return r.Running == false
}, time.Second, time.Millisecond)
// after shutdown, we can start it again
require.NoError(t, r.StartRoutine())
assert.True(t, r.Running)
}
func TestStopper_CannotStartTwice(t *testing.T) {
r := routine{}
require.False(t, r.Running)
require.NoError(t, r.StartRoutine())
assert.ErrorIs(t, r.StartRoutine(), ErrAlreadyStarted)
}
type routine struct {
stopper
Running bool
}
func (p *routine) StartRoutine() error {
if err := p.start(); err != nil {
return err
}
p.Running = true
waitForGoRoutine := make(chan struct{})
go func() {
close(waitForGoRoutine)
<-p.stopCh
p.Running = false
}()
<-waitForGoRoutine
return nil
}

View File

@@ -6,6 +6,7 @@ import (
"io"
"net"
"strconv"
"sync/atomic"
"time"
reuseport "github.com/libp2p/go-reuseport"
@@ -99,6 +100,11 @@ func (cb *DefaultErrorCallback) Callback(name string, id int, start, end time.Ti
}
func UDPRoutine(name string, decodeFunc decoder.DecoderFunc, workers int, addr string, port int, sockReuse bool, logger Logger) error {
return UDPStoppableRoutine(make(chan struct{}), name, decodeFunc, workers, addr, port, sockReuse, logger)
}
// UDPStoppableRoutine runs a UDPRoutine that can be stopped by closing the stopCh passed as argument
func UDPStoppableRoutine(stopCh <-chan struct{}, name string, decodeFunc decoder.DecoderFunc, workers int, addr string, port int, sockReuse bool, logger Logger) error {
ecb := DefaultErrorCallback{
Logger: logger,
}
@@ -146,41 +152,71 @@ func UDPRoutine(name string, decodeFunc decoder.DecoderFunc, workers int, addr s
localIP = ""
}
for {
size, pktAddr, _ := udpconn.ReadFromUDP(payload)
payloadCut := make([]byte, size)
copy(payloadCut, payload[0:size])
type udpData struct {
size int
pktAddr *net.UDPAddr
}
baseMessage := BaseMessage{
Src: pktAddr.IP,
Port: pktAddr.Port,
Payload: payloadCut,
stopped := atomic.Value{}
stopped.Store(false)
udpDataCh := make(chan udpData)
go func() {
for {
u := udpData{}
u.size, u.pktAddr, _ = udpconn.ReadFromUDP(payload)
if stopped.Load() == false {
udpDataCh <- u
} else {
return
}
}
}()
for {
select {
case u := <-udpDataCh:
process(u.size, payload, u.pktAddr, processor, localIP, addrUDP, name)
case <-stopCh:
stopped.Store(true)
udpconn.Close()
close(udpDataCh)
return nil
}
processor.ProcessMessage(baseMessage)
MetricTrafficBytes.With(
prometheus.Labels{
"remote_ip": pktAddr.IP.String(),
"local_ip": localIP,
"local_port": strconv.Itoa(addrUDP.Port),
"type": name,
}).
Add(float64(size))
MetricTrafficPackets.With(
prometheus.Labels{
"remote_ip": pktAddr.IP.String(),
"local_ip": localIP,
"local_port": strconv.Itoa(addrUDP.Port),
"type": name,
}).
Inc()
MetricPacketSizeSum.With(
prometheus.Labels{
"remote_ip": pktAddr.IP.String(),
"local_ip": localIP,
"local_port": strconv.Itoa(addrUDP.Port),
"type": name,
}).
Observe(float64(size))
}
}
func process(size int, payload []byte, pktAddr *net.UDPAddr, processor decoder.Processor, localIP string, addrUDP net.UDPAddr, name string) {
payloadCut := make([]byte, size)
copy(payloadCut, payload[0:size])
baseMessage := BaseMessage{
Src: pktAddr.IP,
Port: pktAddr.Port,
Payload: payloadCut,
}
processor.ProcessMessage(baseMessage)
MetricTrafficBytes.With(
prometheus.Labels{
"remote_ip": pktAddr.IP.String(),
"local_ip": localIP,
"local_port": strconv.Itoa(addrUDP.Port),
"type": name,
}).
Add(float64(size))
MetricTrafficPackets.With(
prometheus.Labels{
"remote_ip": pktAddr.IP.String(),
"local_ip": localIP,
"local_port": strconv.Itoa(addrUDP.Port),
"type": name,
}).
Inc()
MetricPacketSizeSum.With(
prometheus.Labels{
"remote_ip": pktAddr.IP.String(),
"local_ip": localIP,
"local_port": strconv.Itoa(addrUDP.Port),
"type": name,
}).
Observe(float64(size))
}

92
utils/utils_test.go Normal file
View File

@@ -0,0 +1,92 @@
package utils
import (
"fmt"
"net"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCancelUDPRoutine(t *testing.T) {
testTimeout := time.After(10 * time.Second)
port, err := getFreeUDPPort()
require.NoError(t, err)
dp := dummyFlowProcessor{}
go func() {
require.NoError(t, dp.FlowRoutine("127.0.0.1", port))
}()
// wait slightly so we give time to the server to accept requests
time.Sleep(100 * time.Millisecond)
sendMessage := func(msg string) error {
conn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return err
}
defer conn.Close()
_, err = conn.Write([]byte(msg))
return err
}
require.NoError(t, sendMessage("message 1"))
require.NoError(t, sendMessage("message 2"))
require.NoError(t, sendMessage("message 3"))
readMessage := func() string {
select {
case msg := <-dp.receivedMessages:
return string(msg.(BaseMessage).Payload)
case <-testTimeout:
require.Fail(t, "test timed out while waiting for message")
return ""
}
}
// in UDP, messages might arrive out of order or duplicate, so whe just verify they arrive
// to avoid flaky tests
require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage())
require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage())
require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage())
dp.Shutdown()
_ = sendMessage("no more messages should be processed")
select {
case msg := <-dp.receivedMessages:
assert.Fail(t, fmt.Sprint(msg))
default:
// everything is correct
}
}
type dummyFlowProcessor struct {
stopper
receivedMessages chan interface{}
}
func (d *dummyFlowProcessor) FlowRoutine(host string, port int) error {
_ = d.start()
d.receivedMessages = make(chan interface{})
return UDPStoppableRoutine(d.stopCh, "test_udp", func(msg interface{}) error {
d.receivedMessages <- msg
return nil
}, 3, host, port, false, logrus.StandardLogger())
}
func getFreeUDPPort() (int, error) {
a, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
return 0, err
}
l, err := net.ListenUDP("udp", a)
if err != nil {
return 0, err
}
defer l.Close()
return l.LocalAddr().(*net.UDPAddr).Port, nil
}