Sergey Matveev's repositories - btrtrc.git/blob - rlreader_test.go
Fix download rate limiter and add test
[btrtrc.git] / rlreader_test.go
1 package torrent
2
3 import (
4         "io"
5         "log"
6         "math/rand"
7         "sync"
8         "time"
9
10         "github.com/bradfitz/iter"
11         "github.com/stretchr/testify/assert"
12         "github.com/stretchr/testify/require"
13         "golang.org/x/time/rate"
14
15         "testing"
16 )
17
// writeN writes n zero bytes in total, spread randomly across ws. Each of
// ws[1:] in turn is handed a random share (possibly empty) of the bytes still
// remaining, and ws[0] is handed whatever is left over. Because rand.Intn(n)
// is always below n, ws[0] receives at least one byte whenever n > 0.
//
// ws must contain at least one writer. n may be zero, in which case every
// writer is handed an empty slice.
//
// It returns the first error reported by any writer. A writer that accepts
// fewer bytes than it was given without reporting an error violates the
// io.Writer contract. That case is reported as io.ErrShortWrite rather than
// by panicking, so callers running writeN in a background goroutine get an
// error instead of a crashed test binary.
func writeN(ws []io.Writer, n int) error {
	b := make([]byte, n)
	// write hands p to w and reports any failure, including a silent short
	// write.
	write := func(w io.Writer, p []byte) error {
		wn, err := w.Write(p)
		if wn != len(p) && err == nil {
			return io.ErrShortWrite
		}
		return err
	}
	for _, w := range ws[1:] {
		// rand.Intn panics unless its argument is positive, so only draw a
		// random share when there is something left to hand out.
		n1 := 0
		if n > 0 {
			n1 = rand.Intn(n)
		}
		if err := write(w, b[:n1]); err != nil {
			return err
		}
		n -= n1
	}
	return write(ws[0], b[:n])
}
39
40 func TestRateLimitReaders(t *testing.T) {
41         const (
42                 numReaders     = 2
43                 bytesPerSecond = 100
44                 burst          = 5
45                 readSize       = 6
46                 writeRounds    = 10
47                 bytesPerRound  = 12
48         )
49         control := rate.NewLimiter(bytesPerSecond, burst)
50         shared := rate.NewLimiter(bytesPerSecond, burst)
51         var (
52                 ws []io.Writer
53                 cs []io.Closer
54         )
55         wg := sync.WaitGroup{}
56         type read struct {
57                 N int
58                 // When the read was allowed.
59                 At time.Time
60         }
61         reads := make(chan read)
62         done := make(chan struct{})
63         for range iter.N(numReaders) {
64                 r, w := io.Pipe()
65                 ws = append(ws, w)
66                 cs = append(cs, w)
67                 wg.Add(1)
68                 go func() {
69                         defer wg.Done()
70                         r := rateLimitedReader{
71                                 l: shared,
72                                 r: r,
73                         }
74                         b := make([]byte, readSize)
75                         for {
76                                 n, err := r.Read(b)
77                                 select {
78                                 case reads <- read{n, r.lastRead}:
79                                 case <-done:
80                                         return
81                                 }
82                                 if err == io.EOF {
83                                         return
84                                 }
85                                 if err != nil {
86                                         panic(err)
87                                 }
88                         }
89                 }()
90         }
91         closeAll := func() {
92                 for _, c := range cs {
93                         c.Close()
94                 }
95         }
96         defer func() {
97                 close(done)
98                 closeAll()
99                 wg.Wait()
100         }()
101         written := 0
102         go func() {
103                 for range iter.N(writeRounds) {
104                         err := writeN(ws, bytesPerRound)
105                         if err != nil {
106                                 log.Printf("error writing: %s", err)
107                                 break
108                         }
109                         written += bytesPerRound
110                 }
111                 closeAll()
112                 wg.Wait()
113                 close(reads)
114         }()
115         totalBytesRead := 0
116         started := time.Now()
117         for r := range reads {
118                 totalBytesRead += r.N
119                 require.False(t, r.At.IsZero())
120                 // Copy what the reader should have done with its reservation.
121                 res := control.ReserveN(r.At, r.N)
122                 // If we don't have to wait with the control, the reader has gone too
123                 // fast.
124                 if res.Delay() > 0 {
125                         log.Printf("%d bytes not allowed at %s", r.N, time.Since(started))
126                         t.FailNow()
127                 }
128         }
129         assert.EqualValues(t, writeRounds*bytesPerRound, totalBytesRead)
130 }