package index

import (
	"bytes"
	"io"
	"os"
	"path/filepath"
	"testing"
)

// noYamlHeader returns a reader over plain text that contains no YAML
// front-matter delimiters at all.
func noYamlHeader() io.ReadSeeker {
	return bytes.NewReader([]byte("just some text"))
}

// incompleteYamlHeader returns a reader whose front matter has opening and
// closing "---" delimiters, but the closing delimiter is NOT followed by a
// newline. TestYamlHeaderFilter expects the filter to reject this input.
func incompleteYamlHeader() io.ReadSeeker {
	return bytes.NewReader([]byte("---\nfoo:bar\ntitle:bizbaz\nauthor:\n-JP Appel\n---"))
}

// completeYamlHeader returns a reader identical to incompleteYamlHeader except
// that the closing "---" delimiter is terminated by a newline.
func completeYamlHeader() io.ReadSeeker {
	return bytes.NewReader([]byte("---\nfoo:bar\ntitle:bizbaz\nauthor:\n-JP Appel\n---\n"))
}

// trailingYamlHeader returns a reader with a newline-terminated front-matter
// block followed by ordinary body text.
func trailingYamlHeader() io.ReadSeeker {
	return bytes.NewReader([]byte("---\nfoo:bar\ntitle:bizbaz\nauthor:\n-JP Appel\n---\nhere are some content\nanother line of text"))
}

// writeTempFile creates a file called name holding content inside a fresh
// t.TempDir() and returns its path together with its FileInfo. os.WriteFile
// reports close errors as well as write errors, so a short or failed write
// fails the test instead of being silently discarded.
func writeTempFile(t *testing.T, name, content string) infoPath {
	t.Helper()
	path := filepath.Join(t.TempDir(), name)
	// 0o666 matches the permission bits os.Create uses (before umask).
	if err := os.WriteFile(path, []byte(content), 0o666); err != nil {
		t.Fatal(err)
	}
	info, err := os.Stat(path)
	if err != nil {
		t.Fatal(err)
	}
	return infoPath{path, info}
}

// extensionless creates a non-empty file named "afile" (no extension) in a
// temporary directory owned by t.
func extensionless(t *testing.T) infoPath {
	t.Helper()
	return writeTempFile(t, "afile", "this is a file")
}

// markdownExtension creates an empty file named "a.md" in a temporary
// directory owned by t.
func markdownExtension(t *testing.T) infoPath {
	t.Helper()
	return writeTempFile(t, "a.md", "")
}

// TestYamlHeaderFilter checks that YamlHeaderFilter accepts content that
// starts with a newline-terminated "---" front-matter block and rejects
// content without one.
func TestYamlHeaderFilter(t *testing.T) {
	tests := []struct {
		name string
		r    io.ReadSeeker
		want bool
	}{
		{"completeYamlHeader", completeYamlHeader(), true},
		{"trailingYamlHeader", trailingYamlHeader(), true},
		{"noYamlHeader", noYamlHeader(), false},
		{"incompleteYamlHeader", incompleteYamlHeader(), false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := YamlHeaderFilter(infoPath{}, tt.r)
			if got != tt.want {
				t.Errorf("YamlHeaderFilter() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestExtensionFilter checks NewExtensionFilter against files with and
// without a ".md" extension. An empty extension is expected to accept every
// file. The reader argument is nil, so these cases assume the extension
// filter decides from the path/FileInfo alone.
func TestExtensionFilter(t *testing.T) {
	tests := []struct {
		name    string
		infoGen func(*testing.T) infoPath
		ext     string
		want    bool
	}{
		{"no extension, accept .md", extensionless, ".md", false},
		{"no extension, accept all", extensionless, "", true},
		{"markdown, accept .md", markdownExtension, ".md", true},
		{"markdown, accept .png", markdownExtension, ".png", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			filter := NewExtensionFilter(tt.ext)
			ip := tt.infoGen(t)
			got := filter(ip, nil)
			if got != tt.want {
				t.Errorf("ExtensionFilter() = %v, want %v", got, tt.want)
			}
		})
	}
}