diff options
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/query/errors.go | 17 | ||||
| -rw-r--r-- | pkg/query/optimizer.go | 100 | ||||
| -rw-r--r-- | pkg/query/outputs.go | 3 | ||||
| -rw-r--r-- | pkg/query/parser.go | 69 | ||||
| -rw-r--r-- | pkg/query/parser_test.go | 3 | ||||
| -rw-r--r-- | pkg/query/query.go | 25 |
6 files changed, 140 insertions, 77 deletions
diff --git a/pkg/query/errors.go b/pkg/query/errors.go index 0a9664f..35f8c19 100644 --- a/pkg/query/errors.go +++ b/pkg/query/errors.go @@ -8,12 +8,25 @@ import ( var ErrQueryFormat = errors.New("Incorrect query format") var ErrDatetimeTokenParse = errors.New("Unrecognized format for datetime token") +// output errors +var ErrUnrecognizedOutputToken = errors.New("Unrecognized output token") +var ErrExpectedMoreStringTokens = errors.New("Expected more string tokens") + +// optimizer errors +var ErrUnexpectedValueType = errors.New("Unexpected value type") +var ErrEmptyResult = errors.New("Queries are contradictory, will lead to an empty result") + + type TokenError struct { got Token gotPrev Token wantPrev string } +type CompileError struct { + s string +} + func (e *TokenError) Error() string { if e.wantPrev != "" { return fmt.Sprintf("Unexpected token: got %s, got previous %s want previous %s", e.got, e.gotPrev, e.wantPrev) @@ -21,3 +34,7 @@ func (e *TokenError) Error() string { return fmt.Sprintf("Unexpected token: got %s, got previous %s", e.got, e.gotPrev) } + +func (e *CompileError) Error() string { + return fmt.Sprintf("Compile error: %s", e.s) +} diff --git a/pkg/query/optimizer.go b/pkg/query/optimizer.go new file mode 100644 index 0000000..16f36dc --- /dev/null +++ b/pkg/query/optimizer.go @@ -0,0 +1,100 @@ +package query + +import ( + "slices" +) + +type Optimizer struct{} + +func StatementCmp(a Statement, b Statement) int { + catDiff := int(a.Category - b.Category) + opDiff := int(a.Operator - b.Operator) + negatedDiff := 0 + if a.Negated && !b.Negated { + negatedDiff = 1 + } else if !a.Negated && b.Negated { + negatedDiff = -1 + } + + return catDiff*100_000 + opDiff*100 + negatedDiff*10 + a.Value.Compare(b.Value) +} + +func StatementEq(a Statement, b Statement) bool { + a.Simplify() + b.Simplify() + return a.Category == b.Category && a.Operator == b.Operator && a.Negated == b.Negated && a.Value.Compare(b.Value) == 0 +} + +// Merge child clauses with their parents when applicable +func (o Optimizer) Flatten(root *Clause) { + stack := make([]*Clause, 0, len(root.Clauses)) + stack = append(stack, root) + for len(stack) != 0 { + top := len(stack) - 1 + node := stack[top] + stack = stack[:top] + + hasMerged := false + + // merge if only child clause + if len(node.Statements) == 0 && len(node.Clauses) == 1 { + child := node.Clauses[0] + + node.Operator = child.Operator + node.Statements = child.Statements + node.Clauses = child.Clauses + } + + // cannot be "modernized", node.Clauses is modified in loop + for i := 0; i < len(node.Clauses); i++ { + child := node.Clauses[i] + + // merge because of commutativity + if node.Operator == child.Operator { + hasMerged = true + node.Statements = append(node.Statements, child.Statements...) + node.Clauses = append(node.Clauses, child.Clauses...) + } else { + stack = append(stack, child) + } + } + + if hasMerged { + numChildren := len(stack) - top + if numChildren > 0 { + node.Clauses = slices.Grow(node.Clauses, numChildren) + node.Clauses = node.Clauses[:numChildren] + copy(node.Clauses, stack[top:top+numChildren]) + } else { + node.Clauses = nil + } + } + } +} + +func (o Optimizer) Compact(c *Clause) { + for clause := range c.DFS() { + clause.Statements = slices.CompactFunc(c.Statements, StatementEq) + } +} + +// if any claus is a strict equality/inequality noop all fuzzy operations +func strictEquality(clause Clause) error { + isStrict := slices.ContainsFunc(clause.Statements, func(stmt Statement) bool { + if stmt.Operator == OP_EQ || stmt.Operator == OP_NE { + return true + } + return false + }) + + if isStrict { + for i := range clause.Statements { + stmt := clause.Statements[i] + if stmt.Operator != OP_EQ && stmt.Operator != OP_NE { + clause.Statements[i] = Statement{} + } + } + } + + return nil +} diff --git a/pkg/query/outputs.go b/pkg/query/outputs.go index 3938f4f..00b9ddf 100644 --- a/pkg/query/outputs.go +++ b/pkg/query/outputs.go @@ -2,15 +2,12 @@ package query import ( "encoding/json" - "errors" "fmt" "strings" "github.com/jpappel/atlas/pkg/index" ) -var ErrUnrecognizedOutputToken = errors.New("Unrecognized output token") -var ErrExpectedMoreStringTokens = errors.New("Expected more string tokens") const DefaultOutputFormat string = "%p %T %d authors:%a tags:%t" diff --git a/pkg/query/parser.go b/pkg/query/parser.go index d7f4fdd..7ac9918 100644 --- a/pkg/query/parser.go +++ b/pkg/query/parser.go @@ -4,7 +4,6 @@ import ( "fmt" "iter" "os" - "slices" "strings" "time" @@ -217,59 +216,33 @@ func tokToOp(t queryTokenType) opType { } } +// Apply negation to a statements operator +func (s *Statement) Simplify() { + if s.Negated && s.Operator != OP_PIPE && s.Operator != OP_ARG && s.Operator != OP_AP { + s.Negated = false + switch s.Operator { + case OP_EQ: + s.Operator = OP_NE + case OP_NE: + s.Operator = OP_EQ + case OP_LT: + s.Operator = OP_GE + case OP_LE: + s.Operator = OP_GT + case OP_GE: + s.Operator = OP_LT + case OP_GT: + s.Operator = OP_LE + } + } +} + func (c Clause) String() string { b := &strings.Builder{} c.buildString(b, 0) return b.String() } -// Merge child clauses with their parents when applicable -func (root *Clause) Flatten() { - stack := make([]*Clause, 0, len(root.Clauses)) - stack = append(stack, root) - for len(stack) != 0 { - top := len(stack) - 1 - node := stack[top] - stack = stack[:top] - - hasMerged := false - - // merge if only child clause - if len(node.Statements) == 0 && len(node.Clauses) == 1 { - child := node.Clauses[0] - - node.Operator = child.Operator - node.Statements = child.Statements - node.Clauses = child.Clauses - } - - // cannot be "modernized", node.Clauses is modified in loop - for i := 0; i < len(node.Clauses); i++ { - child := node.Clauses[i] - - // merge because of commutativity - if node.Operator == child.Operator { - hasMerged = true - node.Statements = append(node.Statements, child.Statements...) - node.Clauses = append(node.Clauses, child.Clauses...) - } else { - stack = append(stack, child) - } - } - - if hasMerged { - numChildren := len(stack) - top - if numChildren > 0 { - node.Clauses = slices.Grow(node.Clauses, numChildren) - node.Clauses = node.Clauses[:numChildren] - copy(node.Clauses, stack[top:top+numChildren]) - } else { - node.Clauses = nil - } - } - } -} - func (c Clause) buildString(b *strings.Builder, level int) { writeIndent(b, level) b.WriteByte('(') diff --git a/pkg/query/parser_test.go b/pkg/query/parser_test.go index c1f25d4..6ea5c10 100644 --- a/pkg/query/parser_test.go +++ b/pkg/query/parser_test.go @@ -163,8 +163,9 @@ func TestClause_Flatten(t *testing.T) { }, } for _, tt := range tests { + o := query.Optimizer{} t.Run(tt.name, func(t *testing.T) { - tt.root.Flatten() + o.Flatten(tt.root) slices.SortFunc(tt.root.Statements, query.StatementCmp) slices.SortFunc(tt.expected.Statements, query.StatementCmp) diff --git a/pkg/query/query.go b/pkg/query/query.go index 5ddc724..57ba3e1 100644 --- a/pkg/query/query.go +++ b/pkg/query/query.go @@ -2,31 +2,6 @@ package query import "strings" -func Generate(ir *QueryIR) (any, error) { - // TODO: implement - return nil, nil -} - -func Compile(query string) (any, error) { - // TODO: logging - clause, err := Parse(Lex(query)) - if err != nil { - return nil, err - } - - ir, err := NewIR(*clause) - if err != nil { - return nil, err - } - - ir, err = Optimize(ir) - if err != nil { - return nil, err - } - - return Generate(ir) -} - func writeIndent(b *strings.Builder, level int) { for range level { b.WriteByte('\t') |
