aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJP Appel <jeanpierre.appel01@gmail.com>2025-06-14 12:49:04 -0400
committerJP Appel <jeanpierre.appel01@gmail.com>2025-06-14 12:53:45 -0400
commit97a2d99d0a3f1609d3d2264e4e54c119ec3801ff (patch)
tree09ad72618de4da04428cfa09ca200e495bd1da5d
parent06d091cc609e90974f8da7e7ae153f3c2a83ee46 (diff)
Move clause tree optimizations
-rw-r--r--pkg/query/errors.go17
-rw-r--r--pkg/query/optimizer.go100
-rw-r--r--pkg/query/outputs.go3
-rw-r--r--pkg/query/parser.go69
-rw-r--r--pkg/query/parser_test.go3
-rw-r--r--pkg/query/query.go25
6 files changed, 140 insertions, 77 deletions
diff --git a/pkg/query/errors.go b/pkg/query/errors.go
index 0a9664f..35f8c19 100644
--- a/pkg/query/errors.go
+++ b/pkg/query/errors.go
@@ -8,12 +8,25 @@ import (
var ErrQueryFormat = errors.New("Incorrect query format")
var ErrDatetimeTokenParse = errors.New("Unrecognized format for datetime token")
+// output errors
+var ErrUnrecognizedOutputToken = errors.New("Unrecognized output token")
+var ErrExpectedMoreStringTokens = errors.New("Expected more string tokens")
+
+// optimizer errors
+var ErrUnexpectedValueType = errors.New("Unexpected value type")
+var ErrEmptyResult = errors.New("Queries are contradictory, will lead to an empty result")
+
+
type TokenError struct {
got Token
gotPrev Token
wantPrev string
}
+type CompileError struct {
+ s string
+}
+
func (e *TokenError) Error() string {
if e.wantPrev != "" {
return fmt.Sprintf("Unexpected token: got %s, got previous %s want previous %s", e.got, e.gotPrev, e.wantPrev)
@@ -21,3 +34,7 @@ func (e *TokenError) Error() string {
return fmt.Sprintf("Unexpected token: got %s, got previous %s", e.got, e.gotPrev)
}
+
+func (e *CompileError) Error() string {
+ return fmt.Sprintf("Compile error: %s", e.s)
+}
diff --git a/pkg/query/optimizer.go b/pkg/query/optimizer.go
new file mode 100644
index 0000000..16f36dc
--- /dev/null
+++ b/pkg/query/optimizer.go
@@ -0,0 +1,100 @@
+package query
+
+import (
+ "slices"
+)
+
+type Optimizer struct{}
+
+func StatementCmp(a Statement, b Statement) int {
+ catDiff := int(a.Category - b.Category)
+ opDiff := int(a.Operator - b.Operator)
+ negatedDiff := 0
+ if a.Negated && !b.Negated {
+ negatedDiff = 1
+ } else if !a.Negated && b.Negated {
+ negatedDiff = -1
+ }
+
+ return catDiff*100_000 + opDiff*100 + negatedDiff*10 + a.Value.Compare(b.Value)
+}
+
+func StatementEq(a Statement, b Statement) bool {
+ a.Simplify()
+ b.Simplify()
+ return a.Category == b.Category && a.Operator == b.Operator && a.Negated == b.Negated && a.Value.Compare(b.Value) == 0
+}
+
+// Merge child clauses with their parents when applicable
+func (o Optimizer) Flatten(root *Clause) {
+ stack := make([]*Clause, 0, len(root.Clauses))
+ stack = append(stack, root)
+ for len(stack) != 0 {
+ top := len(stack) - 1
+ node := stack[top]
+ stack = stack[:top]
+
+ hasMerged := false
+
+ // merge if only child clause
+ if len(node.Statements) == 0 && len(node.Clauses) == 1 {
+ child := node.Clauses[0]
+
+ node.Operator = child.Operator
+ node.Statements = child.Statements
+ node.Clauses = child.Clauses
+ }
+
+ // cannot be "modernized", node.Clauses is modified in loop
+ for i := 0; i < len(node.Clauses); i++ {
+ child := node.Clauses[i]
+
+ // merge because of commutativity
+ if node.Operator == child.Operator {
+ hasMerged = true
+ node.Statements = append(node.Statements, child.Statements...)
+ node.Clauses = append(node.Clauses, child.Clauses...)
+ } else {
+ stack = append(stack, child)
+ }
+ }
+
+ if hasMerged {
+ numChildren := len(stack) - top
+ if numChildren > 0 {
+ node.Clauses = slices.Grow(node.Clauses, numChildren)
+ node.Clauses = node.Clauses[:numChildren]
+ copy(node.Clauses, stack[top:top+numChildren])
+ } else {
+ node.Clauses = nil
+ }
+ }
+ }
+}
+
+func (o Optimizer) Compact(c *Clause) {
+ for clause := range c.DFS() {
+ clause.Statements = slices.CompactFunc(c.Statements, StatementEq)
+ }
+}
+
+// if any claus is a strict equality/inequality noop all fuzzy operations
+func strictEquality(clause Clause) error {
+ isStrict := slices.ContainsFunc(clause.Statements, func(stmt Statement) bool {
+ if stmt.Operator == OP_EQ || stmt.Operator == OP_NE {
+ return true
+ }
+ return false
+ })
+
+ if isStrict {
+ for i := range clause.Statements {
+ stmt := clause.Statements[i]
+ if stmt.Operator != OP_EQ && stmt.Operator != OP_NE {
+ clause.Statements[i] = Statement{}
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/query/outputs.go b/pkg/query/outputs.go
index 3938f4f..00b9ddf 100644
--- a/pkg/query/outputs.go
+++ b/pkg/query/outputs.go
@@ -2,15 +2,12 @@ package query
import (
"encoding/json"
- "errors"
"fmt"
"strings"
"github.com/jpappel/atlas/pkg/index"
)
-var ErrUnrecognizedOutputToken = errors.New("Unrecognized output token")
-var ErrExpectedMoreStringTokens = errors.New("Expected more string tokens")
const DefaultOutputFormat string = "%p %T %d authors:%a tags:%t"
diff --git a/pkg/query/parser.go b/pkg/query/parser.go
index d7f4fdd..7ac9918 100644
--- a/pkg/query/parser.go
+++ b/pkg/query/parser.go
@@ -4,7 +4,6 @@ import (
"fmt"
"iter"
"os"
- "slices"
"strings"
"time"
@@ -217,59 +216,33 @@ func tokToOp(t queryTokenType) opType {
}
}
+// Apply negation to a statements operator
+func (s *Statement) Simplify() {
+ if s.Negated && s.Operator != OP_PIPE && s.Operator != OP_ARG && s.Operator != OP_AP {
+ s.Negated = false
+ switch s.Operator {
+ case OP_EQ:
+ s.Operator = OP_NE
+ case OP_NE:
+ s.Operator = OP_EQ
+ case OP_LT:
+ s.Operator = OP_GE
+ case OP_LE:
+ s.Operator = OP_GT
+ case OP_GE:
+ s.Operator = OP_LT
+ case OP_GT:
+ s.Operator = OP_LE
+ }
+ }
+}
+
func (c Clause) String() string {
b := &strings.Builder{}
c.buildString(b, 0)
return b.String()
}
-// Merge child clauses with their parents when applicable
-func (root *Clause) Flatten() {
- stack := make([]*Clause, 0, len(root.Clauses))
- stack = append(stack, root)
- for len(stack) != 0 {
- top := len(stack) - 1
- node := stack[top]
- stack = stack[:top]
-
- hasMerged := false
-
- // merge if only child clause
- if len(node.Statements) == 0 && len(node.Clauses) == 1 {
- child := node.Clauses[0]
-
- node.Operator = child.Operator
- node.Statements = child.Statements
- node.Clauses = child.Clauses
- }
-
- // cannot be "modernized", node.Clauses is modified in loop
- for i := 0; i < len(node.Clauses); i++ {
- child := node.Clauses[i]
-
- // merge because of commutativity
- if node.Operator == child.Operator {
- hasMerged = true
- node.Statements = append(node.Statements, child.Statements...)
- node.Clauses = append(node.Clauses, child.Clauses...)
- } else {
- stack = append(stack, child)
- }
- }
-
- if hasMerged {
- numChildren := len(stack) - top
- if numChildren > 0 {
- node.Clauses = slices.Grow(node.Clauses, numChildren)
- node.Clauses = node.Clauses[:numChildren]
- copy(node.Clauses, stack[top:top+numChildren])
- } else {
- node.Clauses = nil
- }
- }
- }
-}
-
func (c Clause) buildString(b *strings.Builder, level int) {
writeIndent(b, level)
b.WriteByte('(')
diff --git a/pkg/query/parser_test.go b/pkg/query/parser_test.go
index c1f25d4..6ea5c10 100644
--- a/pkg/query/parser_test.go
+++ b/pkg/query/parser_test.go
@@ -163,8 +163,9 @@ func TestClause_Flatten(t *testing.T) {
},
}
for _, tt := range tests {
+ o := query.Optimizer{}
t.Run(tt.name, func(t *testing.T) {
- tt.root.Flatten()
+ o.Flatten(tt.root)
slices.SortFunc(tt.root.Statements, query.StatementCmp)
slices.SortFunc(tt.expected.Statements, query.StatementCmp)
diff --git a/pkg/query/query.go b/pkg/query/query.go
index 5ddc724..57ba3e1 100644
--- a/pkg/query/query.go
+++ b/pkg/query/query.go
@@ -2,31 +2,6 @@ package query
import "strings"
-func Generate(ir *QueryIR) (any, error) {
- // TODO: implement
- return nil, nil
-}
-
-func Compile(query string) (any, error) {
- // TODO: logging
- clause, err := Parse(Lex(query))
- if err != nil {
- return nil, err
- }
-
- ir, err := NewIR(*clause)
- if err != nil {
- return nil, err
- }
-
- ir, err = Optimize(ir)
- if err != nil {
- return nil, err
- }
-
- return Generate(ir)
-}
-
func writeIndent(b *strings.Builder, level int) {
for range level {
b.WriteByte('\t')