mirror of
https://github.com/openobserve/goflow2.git
synced 2025-11-02 04:53:27 +00:00
Allow Flow Routines to be cancellable (#40)
* Allow Flow Routines to be cancellable
This commit is contained in:
@@ -51,6 +51,8 @@ func (s *TemplateSystem) GetTemplate(version uint16, obsDomainId uint32, templat
|
||||
}
|
||||
|
||||
type StateNetFlow struct {
|
||||
stopper
|
||||
|
||||
Format format.FormatInterface
|
||||
Transport transport.TransportInterface
|
||||
Logger Logger
|
||||
@@ -373,7 +375,10 @@ func (s *StateNetFlow) initConfig() {
|
||||
}
|
||||
|
||||
func (s *StateNetFlow) FlowRoutine(workers int, addr string, port int, reuseport bool) error {
|
||||
if err := s.start(); err != nil {
|
||||
return err
|
||||
}
|
||||
s.InitTemplates()
|
||||
s.initConfig()
|
||||
return UDPRoutine("NetFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
|
||||
return UDPStoppableRoutine(s.stopCh, "NetFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
)
|
||||
|
||||
type StateNFLegacy struct {
|
||||
stopper
|
||||
|
||||
Format format.FormatInterface
|
||||
Transport transport.TransportInterface
|
||||
Logger Logger
|
||||
@@ -95,5 +97,8 @@ func (s *StateNFLegacy) DecodeFlow(msg interface{}) error {
|
||||
}
|
||||
|
||||
func (s *StateNFLegacy) FlowRoutine(workers int, addr string, port int, reuseport bool) error {
|
||||
return UDPRoutine("NetFlowV5", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
|
||||
if err := s.start(); err != nil {
|
||||
return err
|
||||
}
|
||||
return UDPStoppableRoutine(s.stopCh, "NetFlowV5", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
)
|
||||
|
||||
type StateSFlow struct {
|
||||
stopper
|
||||
|
||||
Format format.FormatInterface
|
||||
Transport transport.TransportInterface
|
||||
Logger Logger
|
||||
@@ -153,6 +155,9 @@ func (s *StateSFlow) initConfig() {
|
||||
}
|
||||
|
||||
func (s *StateSFlow) FlowRoutine(workers int, addr string, port int, reuseport bool) error {
|
||||
if err := s.start(); err != nil {
|
||||
return err
|
||||
}
|
||||
s.initConfig()
|
||||
return UDPRoutine("sFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
|
||||
return UDPStoppableRoutine(s.stopCh, "sFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger)
|
||||
}
|
||||
|
||||
28
utils/stopper.go
Normal file
28
utils/stopper.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ErrAlreadyStarted error happens when you try to start twice a flow routine
|
||||
var ErrAlreadyStarted = errors.New("the routine is already started")
|
||||
|
||||
// stopper mechanism, common for all the flow routines
|
||||
type stopper struct {
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func (s *stopper) start() error {
|
||||
if s.stopCh != nil {
|
||||
return ErrAlreadyStarted
|
||||
}
|
||||
s.stopCh = make(chan struct{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stopper) Shutdown() {
|
||||
if s.stopCh != nil {
|
||||
close(s.stopCh)
|
||||
s.stopCh = nil
|
||||
}
|
||||
}
|
||||
51
utils/stopper_test.go
Normal file
51
utils/stopper_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStopper(t *testing.T) {
|
||||
r := routine{}
|
||||
require.False(t, r.Running)
|
||||
require.NoError(t, r.StartRoutine())
|
||||
assert.True(t, r.Running)
|
||||
r.Shutdown()
|
||||
assert.Eventually(t, func() bool {
|
||||
return r.Running == false
|
||||
}, time.Second, time.Millisecond)
|
||||
|
||||
// after shutdown, we can start it again
|
||||
require.NoError(t, r.StartRoutine())
|
||||
assert.True(t, r.Running)
|
||||
}
|
||||
|
||||
func TestStopper_CannotStartTwice(t *testing.T) {
|
||||
r := routine{}
|
||||
require.False(t, r.Running)
|
||||
require.NoError(t, r.StartRoutine())
|
||||
assert.ErrorIs(t, r.StartRoutine(), ErrAlreadyStarted)
|
||||
}
|
||||
|
||||
type routine struct {
|
||||
stopper
|
||||
Running bool
|
||||
}
|
||||
|
||||
func (p *routine) StartRoutine() error {
|
||||
if err := p.start(); err != nil {
|
||||
return err
|
||||
}
|
||||
p.Running = true
|
||||
waitForGoRoutine := make(chan struct{})
|
||||
go func() {
|
||||
close(waitForGoRoutine)
|
||||
<-p.stopCh
|
||||
p.Running = false
|
||||
}()
|
||||
<-waitForGoRoutine
|
||||
return nil
|
||||
}
|
||||
104
utils/utils.go
104
utils/utils.go
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
reuseport "github.com/libp2p/go-reuseport"
|
||||
@@ -99,6 +100,11 @@ func (cb *DefaultErrorCallback) Callback(name string, id int, start, end time.Ti
|
||||
}
|
||||
|
||||
func UDPRoutine(name string, decodeFunc decoder.DecoderFunc, workers int, addr string, port int, sockReuse bool, logger Logger) error {
|
||||
return UDPStoppableRoutine(make(chan struct{}), name, decodeFunc, workers, addr, port, sockReuse, logger)
|
||||
}
|
||||
|
||||
// UDPStoppableRoutine runs a UDPRoutine that can be stopped by closing the stopCh passed as argument
|
||||
func UDPStoppableRoutine(stopCh <-chan struct{}, name string, decodeFunc decoder.DecoderFunc, workers int, addr string, port int, sockReuse bool, logger Logger) error {
|
||||
ecb := DefaultErrorCallback{
|
||||
Logger: logger,
|
||||
}
|
||||
@@ -146,41 +152,71 @@ func UDPRoutine(name string, decodeFunc decoder.DecoderFunc, workers int, addr s
|
||||
localIP = ""
|
||||
}
|
||||
|
||||
for {
|
||||
size, pktAddr, _ := udpconn.ReadFromUDP(payload)
|
||||
payloadCut := make([]byte, size)
|
||||
copy(payloadCut, payload[0:size])
|
||||
type udpData struct {
|
||||
size int
|
||||
pktAddr *net.UDPAddr
|
||||
}
|
||||
|
||||
baseMessage := BaseMessage{
|
||||
Src: pktAddr.IP,
|
||||
Port: pktAddr.Port,
|
||||
Payload: payloadCut,
|
||||
stopped := atomic.Value{}
|
||||
stopped.Store(false)
|
||||
udpDataCh := make(chan udpData)
|
||||
go func() {
|
||||
for {
|
||||
u := udpData{}
|
||||
u.size, u.pktAddr, _ = udpconn.ReadFromUDP(payload)
|
||||
if stopped.Load() == false {
|
||||
udpDataCh <- u
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case u := <-udpDataCh:
|
||||
process(u.size, payload, u.pktAddr, processor, localIP, addrUDP, name)
|
||||
case <-stopCh:
|
||||
stopped.Store(true)
|
||||
udpconn.Close()
|
||||
close(udpDataCh)
|
||||
return nil
|
||||
}
|
||||
processor.ProcessMessage(baseMessage)
|
||||
|
||||
MetricTrafficBytes.With(
|
||||
prometheus.Labels{
|
||||
"remote_ip": pktAddr.IP.String(),
|
||||
"local_ip": localIP,
|
||||
"local_port": strconv.Itoa(addrUDP.Port),
|
||||
"type": name,
|
||||
}).
|
||||
Add(float64(size))
|
||||
MetricTrafficPackets.With(
|
||||
prometheus.Labels{
|
||||
"remote_ip": pktAddr.IP.String(),
|
||||
"local_ip": localIP,
|
||||
"local_port": strconv.Itoa(addrUDP.Port),
|
||||
"type": name,
|
||||
}).
|
||||
Inc()
|
||||
MetricPacketSizeSum.With(
|
||||
prometheus.Labels{
|
||||
"remote_ip": pktAddr.IP.String(),
|
||||
"local_ip": localIP,
|
||||
"local_port": strconv.Itoa(addrUDP.Port),
|
||||
"type": name,
|
||||
}).
|
||||
Observe(float64(size))
|
||||
}
|
||||
}
|
||||
|
||||
func process(size int, payload []byte, pktAddr *net.UDPAddr, processor decoder.Processor, localIP string, addrUDP net.UDPAddr, name string) {
|
||||
payloadCut := make([]byte, size)
|
||||
copy(payloadCut, payload[0:size])
|
||||
|
||||
baseMessage := BaseMessage{
|
||||
Src: pktAddr.IP,
|
||||
Port: pktAddr.Port,
|
||||
Payload: payloadCut,
|
||||
}
|
||||
processor.ProcessMessage(baseMessage)
|
||||
|
||||
MetricTrafficBytes.With(
|
||||
prometheus.Labels{
|
||||
"remote_ip": pktAddr.IP.String(),
|
||||
"local_ip": localIP,
|
||||
"local_port": strconv.Itoa(addrUDP.Port),
|
||||
"type": name,
|
||||
}).
|
||||
Add(float64(size))
|
||||
MetricTrafficPackets.With(
|
||||
prometheus.Labels{
|
||||
"remote_ip": pktAddr.IP.String(),
|
||||
"local_ip": localIP,
|
||||
"local_port": strconv.Itoa(addrUDP.Port),
|
||||
"type": name,
|
||||
}).
|
||||
Inc()
|
||||
MetricPacketSizeSum.With(
|
||||
prometheus.Labels{
|
||||
"remote_ip": pktAddr.IP.String(),
|
||||
"local_ip": localIP,
|
||||
"local_port": strconv.Itoa(addrUDP.Port),
|
||||
"type": name,
|
||||
}).
|
||||
Observe(float64(size))
|
||||
}
|
||||
|
||||
92
utils/utils_test.go
Normal file
92
utils/utils_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCancelUDPRoutine(t *testing.T) {
|
||||
testTimeout := time.After(10 * time.Second)
|
||||
port, err := getFreeUDPPort()
|
||||
require.NoError(t, err)
|
||||
dp := dummyFlowProcessor{}
|
||||
go func() {
|
||||
require.NoError(t, dp.FlowRoutine("127.0.0.1", port))
|
||||
}()
|
||||
|
||||
// wait slightly so we give time to the server to accept requests
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
sendMessage := func(msg string) error {
|
||||
conn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
_, err = conn.Write([]byte(msg))
|
||||
return err
|
||||
}
|
||||
require.NoError(t, sendMessage("message 1"))
|
||||
require.NoError(t, sendMessage("message 2"))
|
||||
require.NoError(t, sendMessage("message 3"))
|
||||
|
||||
readMessage := func() string {
|
||||
select {
|
||||
case msg := <-dp.receivedMessages:
|
||||
return string(msg.(BaseMessage).Payload)
|
||||
case <-testTimeout:
|
||||
require.Fail(t, "test timed out while waiting for message")
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// in UDP, messages might arrive out of order or duplicate, so whe just verify they arrive
|
||||
// to avoid flaky tests
|
||||
require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage())
|
||||
require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage())
|
||||
require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage())
|
||||
|
||||
dp.Shutdown()
|
||||
|
||||
_ = sendMessage("no more messages should be processed")
|
||||
|
||||
select {
|
||||
case msg := <-dp.receivedMessages:
|
||||
assert.Fail(t, fmt.Sprint(msg))
|
||||
default:
|
||||
// everything is correct
|
||||
}
|
||||
}
|
||||
|
||||
type dummyFlowProcessor struct {
|
||||
stopper
|
||||
receivedMessages chan interface{}
|
||||
}
|
||||
|
||||
func (d *dummyFlowProcessor) FlowRoutine(host string, port int) error {
|
||||
_ = d.start()
|
||||
d.receivedMessages = make(chan interface{})
|
||||
return UDPStoppableRoutine(d.stopCh, "test_udp", func(msg interface{}) error {
|
||||
d.receivedMessages <- msg
|
||||
return nil
|
||||
}, 3, host, port, false, logrus.StandardLogger())
|
||||
}
|
||||
|
||||
func getFreeUDPPort() (int, error) {
|
||||
a, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
l, err := net.ListenUDP("udp", a)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer l.Close()
|
||||
return l.LocalAddr().(*net.UDPAddr).Port, nil
|
||||
}
|
||||
Reference in New Issue
Block a user