]> Sergey Matveev's repositories - btrtrc.git/blob - rlreader_test.go
sortimports
[btrtrc.git] / rlreader_test.go
1 package torrent
2
3 import (
4         "io"
5         "log"
6         "math/rand"
7         "sync"
8         "testing"
9         "time"
10
11         "github.com/bradfitz/iter"
12         "github.com/stretchr/testify/assert"
13         "github.com/stretchr/testify/require"
14         "golang.org/x/time/rate"
15 )
16
// writeN writes n zero bytes in total, split across ws. Each writer after the
// first receives a random-length (possibly empty) chunk of the bytes still
// remaining, and ws[0] receives whatever is left over, so ws[0] always gets at
// least one byte when n > 0. The first non-nil error reported by any writer is
// returned. With no writers, writeN returns nil for n == 0 and
// io.ErrShortWrite otherwise. A writer that reports a short write without an
// error violates the io.Writer contract and causes a panic. n must not be
// negative.
func writeN(ws []io.Writer, n int) error {
	if len(ws) == 0 {
		if n == 0 {
			return nil
		}
		return io.ErrShortWrite
	}
	b := make([]byte, n)
	for _, w := range ws[1:] {
		// rand.Intn panics on a non-positive bound, so a zero budget gives
		// this writer an empty chunk instead.
		n1 := 0
		if n > 0 {
			n1 = rand.Intn(n)
		}
		wn, err := w.Write(b[:n1])
		if err != nil {
			return err
		}
		if wn != n1 {
			panic(n1)
		}
		n -= n1
	}
	wn, err := ws[0].Write(b[:n])
	if err != nil {
		return err
	}
	if wn != n {
		panic(n)
	}
	return nil
}
38
39 func TestRateLimitReaders(t *testing.T) {
40         const (
41                 numReaders     = 2
42                 bytesPerSecond = 100
43                 burst          = 5
44                 readSize       = 6
45                 writeRounds    = 10
46                 bytesPerRound  = 12
47         )
48         control := rate.NewLimiter(bytesPerSecond, burst)
49         shared := rate.NewLimiter(bytesPerSecond, burst)
50         var (
51                 ws []io.Writer
52                 cs []io.Closer
53         )
54         wg := sync.WaitGroup{}
55         type read struct {
56                 N int
57                 // When the read was allowed.
58                 At time.Time
59         }
60         reads := make(chan read)
61         done := make(chan struct{})
62         for range iter.N(numReaders) {
63                 r, w := io.Pipe()
64                 ws = append(ws, w)
65                 cs = append(cs, w)
66                 wg.Add(1)
67                 go func() {
68                         defer wg.Done()
69                         r := rateLimitedReader{
70                                 l: shared,
71                                 r: r,
72                         }
73                         b := make([]byte, readSize)
74                         for {
75                                 n, err := r.Read(b)
76                                 select {
77                                 case reads <- read{n, r.lastRead}:
78                                 case <-done:
79                                         return
80                                 }
81                                 if err == io.EOF {
82                                         return
83                                 }
84                                 if err != nil {
85                                         panic(err)
86                                 }
87                         }
88                 }()
89         }
90         closeAll := func() {
91                 for _, c := range cs {
92                         c.Close()
93                 }
94         }
95         defer func() {
96                 close(done)
97                 closeAll()
98                 wg.Wait()
99         }()
100         written := 0
101         go func() {
102                 for range iter.N(writeRounds) {
103                         err := writeN(ws, bytesPerRound)
104                         if err != nil {
105                                 log.Printf("error writing: %s", err)
106                                 break
107                         }
108                         written += bytesPerRound
109                 }
110                 closeAll()
111                 wg.Wait()
112                 close(reads)
113         }()
114         totalBytesRead := 0
115         started := time.Now()
116         for r := range reads {
117                 totalBytesRead += r.N
118                 require.False(t, r.At.IsZero())
119                 // Copy what the reader should have done with its reservation.
120                 res := control.ReserveN(r.At, r.N)
121                 // If we don't have to wait with the control, the reader has gone too
122                 // fast.
123                 if res.Delay() > 0 {
124                         log.Printf("%d bytes not allowed at %s", r.N, time.Since(started))
125                         t.FailNow()
126                 }
127         }
128         assert.EqualValues(t, writeRounds*bytesPerRound, totalBytesRead)
129 }