]> Sergey Matveev's repositories - btrtrc.git/blob - mmap_span/mmap_span.go
mmap_span: Synchronize access to memory maps to avoid race when unmapping
[btrtrc.git] / mmap_span / mmap_span.go
1 package mmap_span
2
3 import (
4         "io"
5         "log"
6         "sync"
7
8         "github.com/edsrzf/mmap-go"
9 )
10
11 type segment struct {
12         *mmap.MMap
13 }
14
15 func (s segment) Size() int64 {
16         return int64(len(*s.MMap))
17 }
18
19 type MMapSpan struct {
20         mu sync.RWMutex
21         span
22 }
23
24 func (ms *MMapSpan) Append(mmap mmap.MMap) {
25         ms.span = append(ms.span, segment{&mmap})
26 }
27
28 func (ms *MMapSpan) Close() error {
29         ms.mu.Lock()
30         defer ms.mu.Unlock()
31         for _, mMap := range ms.span {
32                 err := mMap.(segment).Unmap()
33                 if err != nil {
34                         log.Print(err)
35                 }
36         }
37         return nil
38 }
39
40 func (ms *MMapSpan) Size() (ret int64) {
41         ms.mu.RLock()
42         defer ms.mu.RUnlock()
43         for _, seg := range ms.span {
44                 ret += seg.Size()
45         }
46         return
47 }
48
49 func (ms *MMapSpan) ReadAt(p []byte, off int64) (n int, err error) {
50         ms.mu.RLock()
51         defer ms.mu.RUnlock()
52         ms.ApplyTo(off, func(intervalOffset int64, interval sizer) (stop bool) {
53                 _n := copy(p, (*interval.(segment).MMap)[intervalOffset:])
54                 p = p[_n:]
55                 n += _n
56                 return len(p) == 0
57         })
58         if len(p) != 0 {
59                 err = io.EOF
60         }
61         return
62 }
63
64 func (ms *MMapSpan) WriteAt(p []byte, off int64) (n int, err error) {
65         ms.mu.RLock()
66         defer ms.mu.RUnlock()
67         ms.ApplyTo(off, func(iOff int64, i sizer) (stop bool) {
68                 mMap := i.(segment)
69                 _n := copy((*mMap.MMap)[iOff:], p)
70                 // err = mMap.Sync(gommap.MS_ASYNC)
71                 // if err != nil {
72                 //      return true
73                 // }
74                 p = p[_n:]
75                 n += _n
76                 return len(p) == 0
77         })
78         if err != nil && len(p) != 0 {
79                 err = io.ErrShortWrite
80         }
81         return
82 }