aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pkg/util/util.go62
-rw-r--r--pkg/util/util_test.go26
2 files changed, 88 insertions, 0 deletions
diff --git a/pkg/util/util.go b/pkg/util/util.go
index 577401e..b26b46a 100644
--- a/pkg/util/util.go
+++ b/pkg/util/util.go
@@ -2,6 +2,7 @@ package util
import (
"iter"
+ "math"
"time"
)
@@ -100,3 +101,64 @@ func BackwardsFilterIter[E any](s []E, cond func(e E) bool) iter.Seq2[int, E] {
}
}
}
+
+// A Levenshtein distance implementation based off of
+//
+// https://en.wikipedia.org/wiki/Levenshtein_distance#Iterative_with_full_matrix
+// PERF: more performant implementations exist
+func LevensteinDistance(s, t string) int {
+ m, n := len(s), len(t)
+ d := make([][]int, m+1)
+ for i := range m + 1 {
+ d[i] = make([]int, n+1)
+ }
+
+ for i := range m {
+ d[i+1][0] = i
+ }
+ for j := range n {
+ d[0][j+1] = j
+ }
+
+ var subCost int
+ for j := range n {
+ for i := range m {
+ if s[i] == t[j] {
+ subCost = 0
+ } else {
+ subCost = 1
+ }
+
+ del := d[i][j+1] + 1
+ insert := d[i+1][j] + 1
+ sub := d[i][j] + subCost
+ d[i+1][j+1] = min(del, insert, sub)
+ }
+ }
+
+ return d[m][n]
+}
+
+// Find nearest element of a slice using cmp, returns the found element and
+// if the distance is below ceil
+func Nearest[E any](candidate E, valid []E, cmp func(E, E) int, ceil int) (E, bool) {
+ minDistance := math.MaxInt
+ minIdx := -1
+ var d int
+ for i, e := range valid {
+ if sd := cmp(candidate, e); sd < 0 {
+ d = -sd
+ } else {
+ d = sd
+ }
+ if d < minDistance {
+ minDistance = d
+ minIdx = i
+ }
+ }
+
+ if minIdx < 0 {
+ return candidate, false
+ }
+ return valid[minIdx], minDistance < ceil
+}
diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go
new file mode 100644
index 0000000..36bb5c9
--- /dev/null
+++ b/pkg/util/util_test.go
@@ -0,0 +1,26 @@
+package util_test
+
+import (
+ "github.com/jpappel/atlas/pkg/util"
+ "testing"
+)
+
+func TestLevensteinDistance(t *testing.T) {
+ tests := []struct {
+ s string
+ t string
+ want int
+ }{
+ {"sitting", "kitten", 3},
+ {"Saturday", "Sunday", 3},
+ {"hello", "kelm", 3},
+ }
+ for _, tt := range tests {
+ t.Run(tt.s+" "+tt.t, func(t *testing.T) {
+ got := util.LevensteinDistance(tt.s, tt.t)
+ if got != tt.want {
+ t.Errorf("LevensteinDistance() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}