aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/server/unix.go
blob: 9d8c266357277257d737e1844ba8a8b06584e0f1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package server

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"sync"
	"time"

	"github.com/jpappel/atlas/pkg/data"
	"github.com/jpappel/atlas/pkg/index"
	"github.com/jpappel/atlas/pkg/query"
)

// Control bytes that frame the Unix socket protocol. Each value is the ASCII
// control character whose name matches its role. The values are part of the
// wire format and must not change.
const (
	START_HEADER byte = 0x01 // SOH: begins a message; header text follows
	START_BODY   byte = 0x02 // STX: ends the header and begins the body
	END_BODY     byte = 0x03 // ETX: ends the body
	END_MSG      byte = 0x04 // EOT: ends a server message
	END_QUERY    byte = 0x05 // ENQ: ends a query sent by the client
)

// UnixServer serves queries over a Unix domain socket. Clients send query
// text terminated by END_QUERY. Each reply is a message that begins with
// START_HEADER and ends with END_MSG.
//
// ListenAndServe initializes the unexported fields and must be called before
// connections are served.
//
// Fields:
//   - Addr: socket address, resolved with net.ResolveUnixAddr("unix", Addr).
//   - Db: executes the compiled query for every request.
//   - WorkersPerConn: passed to query.Compile for every query (presumably a
//     worker count; confirm against query.Compile).
//   - ln: the listener created by ListenAndServe.
//   - conns: open connections keyed by connection id; guarded by lock.
//   - lock: guards conns.
//   - bufPool: read buffers for connection handlers (1024-byte slices, New is
//     set by ListenAndServe).
type UnixServer struct {
	Addr           string
	Db             *data.Query
	WorkersPerConn uint
	ln             *net.UnixListener
	conns          map[uint64]*net.UnixConn
	lock           sync.RWMutex
	bufPool        sync.Pool
}

// ListenAndServe listens on the Unix domain socket at s.Addr and serves every
// accepted connection on its own goroutine (see handleConn).
//
// It blocks until the listener is closed, normally by Shutdown, and then
// returns nil. Any other Accept failure closes the listener and is returned
// to the caller, rather than silently ending the accept loop.
func (s *UnixServer) ListenAndServe() error {
	addr, err := net.ResolveUnixAddr("unix", s.Addr)
	if err != nil {
		return err
	}

	s.ln, err = net.ListenUnix("unix", addr)
	if err != nil {
		return err
	}

	// Guarded so a concurrent Shutdown (which reads conns under the lock)
	// never observes a half-published map.
	s.lock.Lock()
	s.conns = make(map[uint64]*net.UnixConn)
	s.lock.Unlock()
	s.bufPool.New = func() any {
		// Read buffer for a single query message; see handleConn.
		return make([]byte, 1024)
	}

	var connID uint64
	for {
		conn, err := s.ln.AcceptUnix()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				// The listener was closed deliberately (e.g. by Shutdown).
				return nil
			}
			// Best effort: stop accepting so clients fail fast instead of
			// queueing on a socket nobody serves. The accept error is the
			// more useful one to report.
			s.ln.Close()
			return fmt.Errorf("accepting unix connection: %w", err)
		}
		connID++
		slog.Info("New connection", slog.Uint64("connId", connID))

		s.lock.Lock()
		s.conns[connID] = conn
		s.lock.Unlock()

		go s.handleConn(conn, connID)
	}
}

// writeError reports a failure to the client as one framed message. The
// header text is always "Error handling query". The body markers are sent
// back to back with nothing between them, and msg follows END_BODY just
// before END_MSG:
//
//	START_HEADER "Error handling query" START_BODY END_BODY <msg> END_MSG
//
// Write errors are ignored.
//
// NOTE(review): writeResults puts its payload between START_BODY and
// END_BODY. Confirm the client expects msg after END_BODY here.
func (s *UnixServer) writeError(conn *net.UnixConn, msg string) {
	parts := [][]byte{
		fmt.Append([]byte{START_HEADER}, "Error handling query"),
		{START_BODY, END_BODY},
		[]byte(msg),
		{END_MSG},
	}
	for _, part := range parts {
		conn.Write(part)
	}
}

// writeResults sends docs to conn as one framed message:
//
//	START_HEADER "Num Docs: <n>" START_BODY <doc>... END_BODY END_MSG
//
// Each document is rendered with query.DefaultOutput, in map iteration
// order. If rendering a document fails, the error is logged and the
// remaining documents are skipped. Errors from writing the framing bytes
// are ignored.
func (s *UnixServer) writeResults(conn *net.UnixConn, docs map[string]*index.Document) {
	header := fmt.Appendf([]byte{START_HEADER}, "Num Docs: %d", len(docs))
	conn.Write(header)
	conn.Write([]byte{START_BODY})

	// The terminators go out even when a document fails to render (or a
	// panic unwinds through here), so the client always sees a closed frame.
	defer func() {
		conn.Write([]byte{END_BODY})
		conn.Write([]byte{END_MSG})
	}()

	out := query.DefaultOutput{}
	for _, doc := range docs {
		_, err := out.WriteDoc(conn, doc)
		if err == nil {
			continue
		}
		slog.Error("Failed to write doc",
			slog.String("err", err.Error()),
		)
		return
	}
}

// handleConn serves queries from one client until any of the following:
// the client disconnects, sends a malformed message, or sends a query that
// fails to compile or execute.
//
// Each Read is expected to deliver exactly one query terminated by
// END_QUERY. The query is compiled with s.WorkersPerConn and executed
// against s.Db with a 400ms timeout. Results go back via writeResults.
// Failures are reported with writeError and end the session. On return the
// connection is closed and removed from s.conns.
//
// NOTE(review): a query split across several reads, or longer than the
// pooled buffer, is rejected as missing its terminator.
func (s *UnixServer) handleConn(conn *net.UnixConn, id uint64) {
	// Registered first so it runs last. Without it the socket stays open
	// after the handler returns, and the client is left waiting for a reply.
	defer conn.Close()
	defer func() {
		s.lock.Lock()
		delete(s.conns, id)
		s.lock.Unlock()
	}()

	buf := s.bufPool.Get().([]byte)
	defer s.bufPool.Put(buf)
	defer slog.Info("Closing connection",
		slog.Uint64("connId", id),
		slog.String("local", conn.LocalAddr().String()),
	)

	for {
		slog.Debug("Waiting for query")
		n, err := conn.Read(buf)
		if n == 0 || err != nil {
			break
		}
		// Slice into a separate variable so buf keeps its full length.
		// Reassigning buf would cap every later Read at this message's size.
		msg := buf[:n]
		if msg[n-1] != END_QUERY {
			slog.Info("Missing ENQ at end of message")
			break
		}

		queryTxt := string(msg[:n-1])
		slog.Debug("Recieved query",
			slog.String("query", queryTxt),
		)

		// TODO: cache compilation artifacts
		artifact, err := query.Compile(queryTxt, 0, s.WorkersPerConn)
		if err != nil {
			slog.Warn("Failed to compile query",
				slog.String("err", err.Error()))
			s.writeError(conn, "query compilation error")
			break
		}

		ctx, cancel := context.WithTimeout(context.Background(), 400*time.Millisecond)
		docs, err := s.Db.Execute(ctx, artifact)
		cancel() // the context is only needed for Execute
		if err != nil {
			slog.Warn("Failed to execute query",
				slog.String("query", queryTxt),
				slog.String("err", err.Error()),
			)
			s.writeError(conn, "query execution error")
			break
		}

		slog.Debug("Sending results")
		s.writeResults(conn, docs)
	}
}

// Shutdown stops the server in three steps:
//  1. It closes the listener, so ListenAndServe returns.
//  2. It sends "Closing Server" to every tracked connection and closes it.
//  3. It waits until every connection handler has removed itself from
//     s.conns, or until ctx is done.
//
// It returns ctx.Err() if ctx ends before all handlers have exited. If
// closing the listener failed for a reason other than it already being
// closed, it returns that error. Otherwise it returns nil. It is safe to call
// before ListenAndServe or more than once.
func (s *UnixServer) Shutdown(ctx context.Context) error {
	var lnErr error
	if s.ln != nil {
		if err := s.ln.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			lnErr = err
		}
	}

	s.lock.RLock()
	for _, conn := range s.conns {
		// NOTE(review): this notice is sent without the START_HEADER/END_MSG
		// framing used by the other replies. Confirm the client tolerates it.
		conn.Write([]byte("Closing Server"))
		conn.Close()
	}
	s.lock.RUnlock()

	// Handlers delete their entry from s.conns when they exit, so an empty
	// map means every handler has finished.
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	for {
		s.lock.RLock()
		remaining := len(s.conns)
		// A connection accepted just before the listener closed may have been
		// registered after the loop above, so close it too. Closing an
		// already-closed conn only returns an error, which is ignored.
		for _, conn := range s.conns {
			conn.Close()
		}
		s.lock.RUnlock()

		if remaining == 0 {
			return lnErr
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
	}
}