aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/server
diff options
context:
space:
mode:
authorJP Appel <jeanpierre.appel01@gmail.com>2025-07-27 20:37:27 -0400
committerJP Appel <jeanpierre.appel01@gmail.com>2025-07-27 20:37:27 -0400
commit75c3d32881a3b382ab6b63f25d177d40a2ca4256 (patch)
tree20c0305d556e6fa1f23316a9aed9b77e6429ef74 /pkg/server
parent3b84e41e8fec9ecf5c9f4885c8908cf94b8999f5 (diff)
Refactor server over unix sockets to use SOCK_STREAM
Diffstat (limited to 'pkg/server')
-rw-r--r--pkg/server/unix.go179
1 files changed, 118 insertions, 61 deletions
diff --git a/pkg/server/unix.go b/pkg/server/unix.go
index 3b2ee9b..a46320d 100644
--- a/pkg/server/unix.go
+++ b/pkg/server/unix.go
@@ -1,100 +1,157 @@
package server
import (
- "bytes"
"context"
+ "fmt"
"log/slog"
"net"
- "os"
+ "sync"
"github.com/jpappel/atlas/pkg/data"
+ "github.com/jpappel/atlas/pkg/index"
"github.com/jpappel/atlas/pkg/query"
)
-// datagram based unix server
+const (
+ START_HEADER byte = 1
+ START_BODY byte = 2
+ END_BODY byte = 3
+ END_MSG byte = 4
+ END_QUERY byte = 5
+)
+
type UnixServer struct {
- Addr string
- Db *data.Query
- conn *net.UnixConn
+ Addr string
+ Db *data.Query
+ WorkersPerConn uint
+ ln *net.UnixListener
+ conns map[uint64]*net.UnixConn
+ lock sync.RWMutex
+ bufPool sync.Pool
}
func (s *UnixServer) ListenAndServe() error {
- serverAddr := s.Addr + "_server.sock"
- clientAddr := s.Addr + "_client.sock"
-
- var err error
- s.conn, err = net.ListenUnixgram(
- "unixgram",
- &net.UnixAddr{Name: serverAddr, Net: "Unix"},
- )
+ addr, err := net.ResolveUnixAddr("unix", s.Addr)
if err != nil {
return err
}
- defer os.RemoveAll(s.Addr)
- slog.Info("Listening on", slog.String("addr", s.Addr))
- var remote *net.UnixAddr
- remote, err = net.ResolveUnixAddr("unixgram", clientAddr)
+ s.ln, err = net.ListenUnix("unix", addr)
if err != nil {
- panic(err)
+ return err
}
- // FIXME: limits queries to 1kb, might have some data overflow into next msg
- buf := make([]byte, 1024)
+
+ s.conns = make(map[uint64]*net.UnixConn)
+ s.bufPool.New = func() any {
+ return make([]byte, 1024)
+ }
+
+ var connId uint64
for {
- n, _, err := s.conn.ReadFromUnix(buf)
+ conn, err := s.ln.AcceptUnix()
if err != nil {
- return err
+ break
+ }
+ connId++
+ slog.Info("New connection", slog.Uint64("connId", connId))
+
+ s.lock.Lock()
+ s.conns[connId] = conn
+ s.lock.Unlock()
+
+ go s.handleConn(conn, connId)
+ }
+
+ return nil
+}
+
+func (s *UnixServer) writeError(conn *net.UnixConn, msg string) {
+ conn.Write(fmt.Append([]byte{START_HEADER}, "Error handling query"))
+ conn.Write([]byte{START_BODY, END_BODY})
+ conn.Write([]byte(msg))
+ conn.Write([]byte{END_MSG})
+}
+
+func (s *UnixServer) writeResults(conn *net.UnixConn, docs map[string]*index.Document) {
+ defer conn.Write([]byte{END_MSG})
+ conn.Write(fmt.Appendf([]byte{START_HEADER}, "Num Docs: %d", len(docs)))
+ conn.Write([]byte{START_BODY})
+ defer conn.Write([]byte{END_BODY})
+
+ o := query.DefaultOutput{}
+ for _, doc := range docs {
+ if _, err := o.WriteDoc(conn, doc); err != nil {
+ slog.Error("Failed to write doc",
+ slog.String("err", err.Error()),
+ )
+ break
+ }
+ }
+}
+
+func (s *UnixServer) handleConn(conn *net.UnixConn, id uint64) {
+ defer func(id uint64) {
+ s.lock.Lock()
+ delete(s.conns, id)
+ s.lock.Unlock()
+ }(id)
+
+ buf := s.bufPool.Get().([]byte)
+ defer s.bufPool.Put(buf)
+ defer slog.Info("Closing connection",
+ slog.String("local", conn.LocalAddr().String()),
+ )
+
+ for {
+ slog.Debug("Waiting for query")
+ n, err := conn.Read(buf)
+ if n == 0 || err != nil {
+ break
}
buf = buf[:n]
- queryTxt := string(buf)
- slog.Debug("New message",
- slog.String("msg", queryTxt),
- slog.String("local", s.conn.LocalAddr().String()),
- slog.String("remote", remote.String()),
+ if buf[len(buf)-1] != 5 {
+ slog.Info("Missing ENQ at end of message")
+ break
+ }
+
+ queryTxt := string(buf[:len(buf)-1])
+ slog.Debug("Recieved query",
+ slog.String("query", queryTxt),
)
- // TODO: set reasonable numWorkers
- // TODO: rwrite error to remote
- artifact, err := query.Compile(queryTxt, 0, 2)
+ // TODO: cache compilation artifacts
+ artifact, err := query.Compile(queryTxt, 0, s.WorkersPerConn)
if err != nil {
- slog.Error("Failed to compile query", slog.String("err", err.Error()))
- return err
+ slog.Warn("Failed to compile query",
+ slog.String("err", err.Error()))
+ s.writeError(conn, "query compilation error")
+ break
}
- // TODO: write error to remote
docs, err := s.Db.Execute(artifact)
if err != nil {
- slog.Error("Failed to execute query",
- slog.String("err", err.Error()))
- return err
+ slog.Warn("Failed to execute query",
+ slog.String("query", queryTxt),
+ slog.String("err", err.Error()),
+ )
+ s.writeError(conn, "query execution error")
+ break
}
- buf := &bytes.Buffer{}
- o := query.DefaultOutput{}
- for _, doc := range docs {
- n, err = o.WriteDoc(buf, doc)
- if err != nil {
- return err
- }
-
- b := buf.Bytes()
- remaining := len(b)
- offset := 0
- for remaining > 0 {
- n, err := s.conn.WriteToUnix(b[offset:remaining], remote)
- if err != nil {
- return err
- }
- remaining -= n
- offset += n
- }
- buf.Reset()
- }
- // EOF
- s.conn.WriteToUnix([]byte{4}, remote)
+ slog.Debug("Sending results")
+ s.writeResults(conn, docs)
}
}
func (s *UnixServer) Shutdown(ctx context.Context) error {
- return s.conn.Close()
+ s.ln.Close()
+ s.lock.RLock()
+ defer s.lock.RUnlock()
+
+ for _, conn := range s.conns {
+ conn.Write([]byte("Closing Server"))
+ conn.Close()
+ }
+
+ return nil
}