Sergey Matveev's repositories - btrtrc.git/blob - rlreader_test.go
Drop support for go 1.20
[btrtrc.git] / rlreader_test.go
1 package torrent
2
3 import (
4         "io"
5         "log"
6         "math/rand"
7         "sync"
8         "testing"
9         "time"
10
11         "github.com/stretchr/testify/assert"
12         "github.com/stretchr/testify/require"
13         "golang.org/x/time/rate"
14 )
15
// writeN writes a total of n zero bytes spread across ws. Every writer except
// the first receives a randomly sized chunk that is strictly smaller than what
// is still left (and may be empty); the first writer is written last and
// receives the remainder.
//
// It returns the first error reported by a writer. With no writers it returns
// nil when n is zero and io.ErrShortWrite otherwise. A writer that reports a
// short write without an error violates the io.Writer contract, so that case
// panics with the length that should have been written.
func writeN(ws []io.Writer, n int) error {
	if len(ws) == 0 {
		if n > 0 {
			return io.ErrShortWrite
		}
		return nil
	}
	// write sends p to w in full, treating a short write with a nil error as
	// a programming error in the writer.
	write := func(w io.Writer, p []byte) error {
		wn, err := w.Write(p)
		if wn != len(p) && err == nil {
			panic(len(p))
		}
		return err
	}
	b := make([]byte, n)
	for _, w := range ws[1:] {
		// rand.Intn panics for a non-positive argument, so when nothing is
		// left to hand out the chunk is simply empty.
		n1 := 0
		if n > 0 {
			n1 = rand.Intn(n)
		}
		if err := write(w, b[:n1]); err != nil {
			return err
		}
		n -= n1
	}
	return write(ws[0], b[:n])
}
37
38 func TestRateLimitReaders(t *testing.T) {
39         const (
40                 numReaders     = 2
41                 bytesPerSecond = 100
42                 burst          = 5
43                 readSize       = 6
44                 writeRounds    = 10
45                 bytesPerRound  = 12
46         )
47         control := rate.NewLimiter(bytesPerSecond, burst)
48         shared := rate.NewLimiter(bytesPerSecond, burst)
49         var (
50                 ws []io.Writer
51                 cs []io.Closer
52         )
53         wg := sync.WaitGroup{}
54         type read struct {
55                 N int
56                 // When the read was allowed.
57                 At time.Time
58         }
59         reads := make(chan read)
60         done := make(chan struct{})
61         for i := 0; i < numReaders; i += 1 {
62                 r, w := io.Pipe()
63                 ws = append(ws, w)
64                 cs = append(cs, w)
65                 wg.Add(1)
66                 go func() {
67                         defer wg.Done()
68                         r := rateLimitedReader{
69                                 l: shared,
70                                 r: r,
71                         }
72                         b := make([]byte, readSize)
73                         for {
74                                 n, err := r.Read(b)
75                                 select {
76                                 case reads <- read{n, r.lastRead}:
77                                 case <-done:
78                                         return
79                                 }
80                                 if err == io.EOF {
81                                         return
82                                 }
83                                 if err != nil {
84                                         panic(err)
85                                 }
86                         }
87                 }()
88         }
89         closeAll := func() {
90                 for _, c := range cs {
91                         c.Close()
92                 }
93         }
94         defer func() {
95                 close(done)
96                 closeAll()
97                 wg.Wait()
98         }()
99         written := 0
100         go func() {
101                 for i := 0; i < writeRounds; i += 1 {
102                         err := writeN(ws, bytesPerRound)
103                         if err != nil {
104                                 log.Printf("error writing: %s", err)
105                                 break
106                         }
107                         written += bytesPerRound
108                 }
109                 closeAll()
110                 wg.Wait()
111                 close(reads)
112         }()
113         totalBytesRead := 0
114         started := time.Now()
115         for r := range reads {
116                 totalBytesRead += r.N
117                 require.False(t, r.At.IsZero())
118                 // Copy what the reader should have done with its reservation.
119                 res := control.ReserveN(r.At, r.N)
120                 // If we don't have to wait with the control, the reader has gone too
121                 // fast.
122                 if res.Delay() > 0 {
123                         log.Printf("%d bytes not allowed at %s", r.N, time.Since(started))
124                         t.FailNow()
125                 }
126         }
127         assert.EqualValues(t, writeRounds*bytesPerRound, totalBytesRead)
128 }