1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
|
package server
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"sync"

	"github.com/jpappel/atlas/pkg/data"
	"github.com/jpappel/atlas/pkg/index"
	"github.com/jpappel/atlas/pkg/query"
)
// Framing bytes used on the wire. Their values, 1 through 5, are the ASCII
// control codes SOH, STX, ETX, EOT and ENQ.
const (
	// START_HEADER precedes the header text of a response.
	START_HEADER byte = iota + 1
	// START_BODY precedes the body of a response.
	START_BODY
	// END_BODY follows the body of a response.
	END_BODY
	// END_MSG is the final byte of a response.
	END_MSG
	// END_QUERY is ENQ (5), the byte handleConn requires at the end of an
	// incoming query.
	END_QUERY
)
// UnixServer answers queries sent over a Unix domain socket.
//
// Set the exported fields, then call ListenAndServe; the unexported fields
// are initialised there.
type UnixServer struct {
	// Addr is the address handed to net.ResolveUnixAddr for the "unix"
	// network when the server starts listening.
	Addr string

	// Db executes the compiled form of each received query.
	Db *data.Query

	// WorkersPerConn is forwarded to query.Compile for every query.
	WorkersPerConn uint

	// ln is the listener created by ListenAndServe and closed by Shutdown.
	ln *net.UnixListener

	// conns holds the currently tracked connections keyed by connection id.
	conns map[uint64]*net.UnixConn

	// lock is held around every access to conns.
	lock sync.RWMutex

	// bufPool supplies the read buffer used by each connection handler; its
	// New function is installed by ListenAndServe.
	bufPool sync.Pool
}
// ListenAndServe binds a Unix domain socket at s.Addr and serves every
// accepted connection on its own goroutine until accepting stops.
//
// It returns an error when the address cannot be resolved, when the socket
// cannot be bound, or when Accept fails for any reason other than the
// listener having been closed. A closed listener (which is what Shutdown
// produces) is the normal way to stop the server and yields nil.
func (s *UnixServer) ListenAndServe() error {
	addr, err := net.ResolveUnixAddr("unix", s.Addr)
	if err != nil {
		return err
	}

	s.ln, err = net.ListenUnix("unix", addr)
	if err != nil {
		return err
	}

	s.conns = make(map[uint64]*net.UnixConn)
	s.bufPool.New = func() any {
		return make([]byte, 1024)
	}

	// Connection ids start at 1 and are only used as map keys and log fields.
	var connID uint64
	for {
		conn, err := s.ln.AcceptUnix()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return nil
			}
			// Anything else is a genuine failure; report it instead of
			// pretending the server stopped cleanly.
			return fmt.Errorf("accepting connection on %q: %w", s.Addr, err)
		}

		connID++
		slog.Info("New connection", slog.Uint64("connId", connID))

		s.lock.Lock()
		s.conns[connID] = conn
		s.lock.Unlock()

		go s.handleConn(conn, connID)
	}
}
// writeError reports a failed query to the peer as one framed response:
//
//	START_HEADER "Error handling query" START_BODY msg END_BODY END_MSG
//
// msg is placed between the body delimiters, the same position writeResults
// uses for its payload; previously it was emitted after END_BODY, leaving the
// body empty and the text outside any frame section.
//
// The frame is sent with a single Write. A failed write is logged rather
// than returned because the caller drops the connection right afterwards.
func (s *UnixServer) writeError(conn *net.UnixConn, msg string) {
	const header = "Error handling query"

	// 4 accounts for the four framing bytes.
	frame := make([]byte, 0, len(header)+len(msg)+4)
	frame = append(frame, START_HEADER)
	frame = append(frame, header...)
	frame = append(frame, START_BODY)
	frame = append(frame, msg...)
	frame = append(frame, END_BODY, END_MSG)

	if _, err := conn.Write(frame); err != nil {
		slog.Warn("Failed to write error response",
			slog.String("err", err.Error()),
		)
	}
}
// writeResults sends docs to conn as one framed response: START_HEADER and a
// "Num Docs: N" header, START_BODY, each document rendered by
// query.DefaultOutput, then END_BODY and END_MSG.
//
// Documents are written in map iteration order. The first document that
// fails to write is logged and ends the loop; the closing bytes are still
// sent.
func (s *UnixServer) writeResults(conn *net.UnixConn, docs map[string]*index.Document) {
	// Deferred so the body and the message are terminated on every return
	// path, in that order.
	defer func() {
		conn.Write([]byte{END_BODY})
		conn.Write([]byte{END_MSG})
	}()

	header := fmt.Appendf([]byte{START_HEADER}, "Num Docs: %d", len(docs))
	conn.Write(header)
	conn.Write([]byte{START_BODY})

	out := query.DefaultOutput{}
	for _, d := range docs {
		_, err := out.WriteDoc(conn, d)
		if err == nil {
			continue
		}
		slog.Error("Failed to write doc",
			slog.String("err", err.Error()),
		)
		return
	}
}
// handleConn serves queries arriving on conn until the peer disconnects, a
// read fails, a message lacks the END_QUERY terminator, or a query fails to
// compile or execute. On return the connection is removed from s.conns and
// closed.
//
// Each query must arrive in a single Read and fit in the pooled buffer,
// with END_QUERY as its last byte.
func (s *UnixServer) handleConn(conn *net.UnixConn, id uint64) {
	// Deferred calls run last-in first-out, so the connection is closed only
	// after it has been removed from s.conns. The Close error is ignored on
	// purpose: Shutdown may already have closed the connection.
	defer conn.Close()
	defer func() {
		s.lock.Lock()
		delete(s.conns, id)
		s.lock.Unlock()
	}()

	buf := s.bufPool.Get().([]byte)
	defer s.bufPool.Put(buf)
	defer slog.Info("Closing connection",
		slog.String("local", conn.LocalAddr().String()),
	)

	for {
		slog.Debug("Waiting for query")
		n, err := conn.Read(buf)
		if n == 0 || err != nil {
			return
		}

		// Slice into a separate variable: re-slicing buf itself would shrink
		// the space available to every later Read on this connection.
		msg := buf[:n]
		if msg[n-1] != END_QUERY {
			slog.Info("Missing ENQ at end of message")
			return
		}

		queryTxt := string(msg[:n-1])
		slog.Debug("Recieved query",
			slog.String("query", queryTxt),
		)

		// TODO: cache compilation artifacts
		artifact, err := query.Compile(queryTxt, 0, s.WorkersPerConn)
		if err != nil {
			slog.Warn("Failed to compile query",
				slog.String("err", err.Error()))
			s.writeError(conn, "query compilation error")
			return
		}

		docs, err := s.Db.Execute(artifact)
		if err != nil {
			slog.Warn("Failed to execute query",
				slog.String("query", queryTxt),
				slog.String("err", err.Error()),
			)
			s.writeError(conn, "query execution error")
			return
		}

		slog.Debug("Sending results")
		s.writeResults(conn, docs)
	}
}
// Shutdown stops accepting new connections, then writes "Closing Server" to
// every tracked connection and closes it.
//
// If ctx carries a deadline it is applied to each farewell write, so a peer
// that has stopped reading cannot stall shutdown while the lock is held. A
// nil ctx is tolerated and treated as having no deadline.
//
// The returned error is the listener's Close error, if any. A listener that
// was never created or is already closed is not an error. Failures while
// notifying or closing individual connections are ignored, since their
// handlers may have closed them already.
func (s *UnixServer) Shutdown(ctx context.Context) error {
	var lnErr error
	if s.ln != nil {
		if err := s.ln.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			lnErr = err
		}
	}

	var (
		deadline    time.Time
		hasDeadline bool
	)
	if ctx != nil {
		deadline, hasDeadline = ctx.Deadline()
	}

	s.lock.RLock()
	defer s.lock.RUnlock()
	for _, conn := range s.conns {
		if hasDeadline {
			conn.SetWriteDeadline(deadline)
		}
		conn.Write([]byte("Closing Server"))
		conn.Close()
	}
	return lnErr
}
|