问题所在
这个改动是必要的。在最初的实现中，`closeProxies` 和 `proxies` 引用的是同一个 map（Go 中对 map 赋值只复制引用，不复制其中的元素），所以之后加入 `proxies` 的代理也会出现在 `closeProxies` 中。请看下面的演示：
package main
import "fmt"
// main demonstrates that assigning a map copies only the reference: entries
// added through one name are visible through the other, so ranging over the
// "copy" and deleting every key empties the original completely.
func main() {
	proxies := make(map[int]int, 0)
	for key := 0; key < 10; key++ {
		proxies[key] = key
	}

	// closeProxies is NOT a snapshot; it is a second name for the same map.
	closeProxies := proxies

	// These late additions therefore show up in closeProxies as well...
	proxies[10], proxies[11] = 10, 11

	// ...and are deleted here too, even though they were added after the
	// "copy" was taken.
	for key := range closeProxies {
		delete(proxies, key)
	}

	fmt.Printf("items left: %d\n", len(proxies))
	// Output:
	// items left: 0
}
但这并不是根本原因。在复制 `closeProxies` 之后、调用 `SetSwitchOver` 之前，仍然可能有新的代理被添加进来。这种情况下，新代理连接的是旧地址，却不在 `closeProxies` 中，因此切换时不会被关闭。我认为这才是根本原因。
还有另一个问题：新代理在其 `To` 字段被设置之前就已经加入了 `proxies`。于是可能出现这样的情况：程序在 `To` 字段被设置之前（此时它仍为 nil）就尝试关闭这个代理，从而导致 panic。
可靠的设计
思路是把所有端点放进一个切片，并让每个端点管理自己的代理列表。这样我们只需要记录当前端点的索引：要切换到另一个端点时，只需修改索引，并通知过时的端点清除它的代理。剩下唯一复杂的地方，是确保过时的端点能够清除其全部代理。具体实现见下文代码：
manager.go
以下是这一思路的实现：
package main
import (
"sync"
)
type (
	// Conn is the minimal view of a connection that Manager needs: something
	// it can close. Keeping it this small makes Manager easy to test with
	// fake connections.
	Conn interface {
		Close() error
	}

	// Dialer opens a connection to an upstream address. It is an interface so
	// that tests can replace real network dialing with a fake.
	Dialer interface {
		Dial(address string) (Conn, error)
	}
)
// Manager hands newly accepted client connections to one of several upstream
// endpoints and can rotate ("switch") which endpoint is active. Create one
// with NewManager; the zero value has no endpoints and is not usable.
type Manager struct {
	// muCurrent guards current. AddProxy holds its read lock for the whole
	// time it spends registering a connection, so taking the write lock also
	// waits for every in-flight AddProxy to finish.
	muCurrent sync.RWMutex

	// current is the index into endpoints of the active endpoint, or -1 once
	// the manager has been shut down.
	current   int
	endpoints []*endpoint

	// mu serializes Switch and Shutdown, so only one of them changes the
	// active endpoint and clears the old one at a time.
	mu sync.Mutex
}
// NewManager returns a Manager that rotates among the given upstream
// addresses, dialing each of them through dialer. The first address is the
// active one initially.
//
// It panics if fewer than two addresses are given: a manager with a single
// endpoint would have nothing to switch to.
func NewManager(dialer Dialer, addresses ...string) *Manager {
	if len(addresses) < 2 {
		panic("a manager should handle at least 2 addresses")
	}
	endpoints := make([]*endpoint, 0, len(addresses))
	for _, addr := range addresses {
		endpoints = append(endpoints, &endpoint{address: addr, dialer: dialer})
	}
	// current is left at its zero value, so endpoints[0] starts out active.
	return &Manager{endpoints: endpoints}
}
// AddProxy registers the accepted connection from with the active endpoint,
// which dials its upstream and records the pair. Once the manager has been
// shut down, from is closed immediately instead.
//
// AddProxy holds the read lock of m.muCurrent for its entire duration,
// including the upstream dial. While Switch holds the write lock, AddProxy
// waits; Switch keeps it only briefly and runs rarely, so that wait is short.
// Conversely, Switch's Lock waits for every AddProxy already holding the read
// lock. That makes Switch slower, but guarantees that by the time it clears
// the old endpoint, no AddProxy is still adding to it. Note that while a
// writer is waiting, sync.RWMutex also blocks new readers, so new AddProxy
// calls can wait behind the slowest in-flight dial.
func (m *Manager) AddProxy(from Conn) {
	m.muCurrent.RLock()
	defer m.muCurrent.RUnlock()

	if idx := m.current; idx != -1 {
		m.endpoints[idx].addProxy(from)
		return
	}
	// The manager is shut down: refuse the new connection.
	from.Close()
}
// Switch makes the next endpoint (wrapping around to the first) the active
// one, then closes every proxy of the endpoint that was active before. It
// does nothing once the manager has been shut down.
//
// Switch is expected to run rarely, so serializing it with m.mu is cheap. It
// also guarantees the previous endpoint is fully cleared, and ready to be
// reused later, before another Switch or Shutdown can start.
func (m *Manager) Switch() {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Rotate the index under the write lock of muCurrent. Acquiring it waits
	// for every AddProxy holding the read lock, so once it is released no
	// AddProxy can still be working on the previous endpoint.
	previous := -1
	func() {
		m.muCurrent.Lock()
		defer m.muCurrent.Unlock()
		if m.current == -1 {
			return // already shut down
		}
		previous = m.current
		next := previous + 1
		if next >= len(m.endpoints) {
			next = 0
		}
		m.current = next
	}()

	if previous == -1 {
		return
	}
	// New AddProxy calls now target the new endpoint, so the old one can be
	// cleared without holding muCurrent.
	m.endpoints[previous].clear()
}
// Shutdown stops the manager and closes every proxy of the active endpoint.
// Afterwards, AddProxy closes any connection handed to it and Switch is a
// no-op. Calling Shutdown more than once is safe; later calls do nothing.
func (m *Manager) Shutdown() {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Taking the write lock waits for every in-flight AddProxy, so once it is
	// released nothing can still be adding to the endpoint cleared below.
	m.muCurrent.Lock()
	current := m.current
	m.current = -1
	m.muCurrent.Unlock()

	// Already shut down: without this check a second call would index
	// m.endpoints[-1] and panic.
	if current == -1 {
		return
	}
	m.endpoints[current].clear()
}
type (
	// proxy pairs a client connection with the upstream connection that was
	// dialed on its behalf.
	proxy struct {
		from Conn // client side, as passed to Manager.AddProxy
		to   Conn // upstream side, as returned by the endpoint's Dialer
	}

	// endpoint is one upstream address together with the proxies that have
	// been connected to it.
	endpoint struct {
		address string
		dialer  Dialer

		// mu protects proxies against concurrent addProxy calls.
		mu      sync.Mutex
		proxies []*proxy
	}
)
// clear closes both sides of every proxy attached to e and forgets them.
//
// It takes e.mu, the same lock addProxy uses when appending, so proxies is
// always accessed under one lock. Manager's locking already prevents
// addProxy from running on an endpoint while it is being cleared; the lock
// keeps clear safe even if a future caller does not provide that guarantee.
func (e *endpoint) clear() {
	e.mu.Lock()
	defer e.mu.Unlock()
	for _, p := range e.proxies {
		// Close errors are ignored on purpose: these connections are being
		// discarded, and there is nothing useful to do on failure.
		_ = p.from.Close()
		_ = p.to.Close()
	}
	// Drop the slice so the closed proxies can be garbage collected.
	e.proxies = nil
}
// addProxy dials e.address and, on success, records from together with the
// new upstream connection. If dialing fails, from is closed instead,
// presumably so that the client notices and reconnects.
//
// The dial happens before e.mu is taken, so a slow upstream never holds the
// lock.
func (e *endpoint) addProxy(from Conn) {
	upstream, err := e.dialer.Dial(e.address)
	if err != nil {
		from.Close()
		return
	}
	p := &proxy{from: from, to: upstream}

	e.mu.Lock()
	e.proxies = append(e.proxies, p)
	e.mu.Unlock()
}
main.go
下面的示例演示如何使用前面实现的 `Manager` 类型：
package main
import (
"net"
"time"
)
// realDialer implements Dialer with plain TCP connections.
type realDialer struct{}

// Dial resolves addr, which must be in "host:port" form, as an IPv4 TCP
// address and connects to it. On any failure it returns a nil Conn together
// with the error.
func (d realDialer) Dial(addr string) (Conn, error) {
	tcpAddr, err := net.ResolveTCPAddr("tcp4", addr)
	if err != nil {
		return nil, err
	}
	conn, err := net.DialTCP("tcp", nil, tcpAddr)
	if err != nil {
		// Return a literal nil. Returning conn here would wrap a nil
		// *net.TCPConn in a non-nil Conn interface, and a caller that checks
		// the Conn against nil would be fooled.
		return nil, err
	}
	return conn, nil
}
// main listens on port 5432 and registers every accepted client with a
// Manager, which dials the currently active upstream for it. A background
// goroutine switches the active upstream every 30 seconds.
func main() {
	// The upstream addresses are placeholders. Each needs an explicit port:
	// realDialer resolves them with net.ResolveTCPAddr, which rejects a bare
	// host with "missing port in address".
	manager := NewManager(realDialer{}, "1.1.1.1:5432", "8.8.8.8:5432")

	tcpAddr, err := net.ResolveTCPAddr("tcp4", "0.0.0.0:5432")
	if err != nil {
		panic(err)
	}
	// Without this check a failed listen would leave ln nil and crash later
	// inside AcceptTCP with a less helpful error.
	ln, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		panic(err)
	}

	go func() {
		for range time.Tick(30 * time.Second) {
			manager.Switch()
		}
	}()

	for {
		clientConn, err := ln.AcceptTCP()
		if err != nil {
			panic(err)
		}
		go manager.AddProxy(clientConn)
	}
}
manager_test.go
使用以下命令运行测试（测试依赖 `github.com/google/uuid`）：`go test ./... -race -count 10`
package main
import (
"errors"
"math/rand"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
)
// TestManager issues many concurrent AddProxy calls while a background
// goroutine keeps switching endpoints. It then shuts the manager down and
// checks that nothing leaked: every fake server has no open connections, and
// the shared client connection was closed exactly once per AddProxy call.
func TestManager(t *testing.T) {
	addresses := []string{"1.1.1.1", "8.8.8.8"}
	dialer := newDialer(addresses...)
	manager := NewManager(dialer, addresses...)

	const batch = 1000
	total := batch * 10

	var wg sync.WaitGroup

	// The switcher performs one Switch per value received.
	switches := make(chan int, 1)
	wg.Add(1)
	go func() {
		defer wg.Done()
		for range switches {
			manager.Switch()
		}
	}()

	// Every client shares one fake connection so closes can be counted.
	fromConn := &fakeFromConn{}
	wg.Add(total)
	for i := 1; i <= total; i++ {
		// Request a switch just before the last client of each batch.
		if i%batch == 0 {
			switches <- 0
		}
		go func() {
			defer wg.Done()
			manager.AddProxy(fromConn)
		}()
	}
	close(switches)
	wg.Wait()

	manager.Shutdown()

	for _, s := range dialer.servers {
		if left := len(s.conns); left != 0 {
			t.Errorf("server %s, unexpected connections left: %d", s.addr, left)
		}
	}
	if got := fromConn.closedCount.Load(); got != int32(total) {
		t.Errorf("want closed count: %d, got: %d", total, got)
	}
}
// fakeFromConn stands in for client connections. Close does nothing except
// count the call, so a test can check how often it happened.
type fakeFromConn struct {
	closedCount atomic.Int32 // incremented by every Close call
}

// Close records the call and always succeeds.
func (f *fakeFromConn) Close() error {
	f.closedCount.Add(1)
	return nil
}
// fakeToConn stands in for an upstream connection created by fakeDialer.
// Closing it removes its id from the owning fakeServer. A connection whose id
// is the nil UUID is never looked up, so its Close is a no-op.
type fakeToConn struct {
	id     uuid.UUID
	server *fakeServer
}

// Close always returns nil.
func (c *fakeToConn) Close() error {
	if c.id != uuid.Nil {
		c.server.removeConn(c.id)
	}
	return nil
}
// fakeServer simulates an upstream server by remembering the ids of the
// connections currently open to it.
type fakeServer struct {
	addr string

	mu    sync.Mutex // guards conns
	conns map[uuid.UUID]bool
}

// addConn registers a new connection under a fresh random id and returns the
// id. If generating the id fails, nothing is registered and the error from
// uuid.NewRandom is returned.
func (s *fakeServer) addConn() (uuid.UUID, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	id, err := uuid.NewRandom()
	if err != nil {
		return id, err
	}
	s.conns[id] = true
	return id, nil
}

// removeConn forgets the connection with the given id. Unknown ids are
// ignored.
func (s *fakeServer) removeConn(id uuid.UUID) {
	s.mu.Lock()
	delete(s.conns, id)
	s.mu.Unlock()
}
// fakeDialer simulates dialing by handing out connections to in-memory
// fakeServers, one per known address.
type fakeDialer struct {
	servers map[string]*fakeServer
}

// newDialer returns a fakeDialer with one empty fakeServer for each address.
func newDialer(addresses ...string) *fakeDialer {
	d := &fakeDialer{servers: make(map[string]*fakeServer, len(addresses))}
	for _, addr := range addresses {
		d.servers[addr] = &fakeServer{
			addr:  addr,
			conns: make(map[uuid.UUID]bool),
		}
	}
	return d
}

// Dial fails with probability 1/100 ("fake network error"). Otherwise it
// sleeps 1-99 ms to mimic network latency and registers a new connection on
// the server for addr. addr must be one of the addresses given to newDialer;
// an unknown address leads to a nil *fakeServer being used.
func (d *fakeDialer) Dial(addr string) (Conn, error) {
	delay := rand.Intn(100)
	if delay == 0 {
		return nil, errors.New("fake network error")
	}
	time.Sleep(time.Duration(delay) * time.Millisecond)

	server := d.servers[addr]
	id, err := server.addConn()
	if err != nil {
		return nil, err
	}
	return &fakeToConn{id: id, server: server}, nil
}