Skip to content

Commit

Permalink
query extractor (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
cp-20 authored Jan 21, 2025
1 parent c217341 commit 80c4268
Show file tree
Hide file tree
Showing 7 changed files with 457 additions and 7 deletions.
69 changes: 69 additions & 0 deletions cli/extractor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package main

import (
"fmt"
"os"
"path/filepath"

"github.com/h24w-17/extractor"
"github.com/spf13/cobra"
)

// rootCmd is the query-extractor command. It takes exactly one positional
// argument (the directory to scan), lists the Go files under it with
// extractor.ListAllGoFiles, extracts SQL-looking string literals from each file
// with extractor.ExtractQueryFromFile, and writes all of them to the file named
// by the "out" flag. Processing stops at the first file that fails.
var rootCmd = &cobra.Command{
	Use:       "query-extractor",
	Long:      "Statically analyze the codebase and extract SQL queries",
	Args:      cobra.ExactArgs(1),
	ValidArgs: []string{"path"},
	RunE: func(cmd *cobra.Command, args []string) error {
		path := args[0]

		out, err := cmd.Flags().GetString("out")
		if err != nil {
			return fmt.Errorf("error getting out flag: %w", err)
		}

		if !extractor.IsValidDir(path) {
			return fmt.Errorf("invalid directory: %s", path)
		}
		files, err := extractor.ListAllGoFiles(path)
		if err != nil {
			return fmt.Errorf("error listing go files: %w", err)
		}

		fmt.Printf("found %d go files\n", len(files))
		var extractedQueries []*extractor.ExtractedQuery
		for _, file := range files {
			// The relative path is only used to keep progress output short.
			relativePath, err := filepath.Rel(path, file)
			if err != nil {
				return fmt.Errorf("error getting relative path: %w", err)
			}
			queries, err := extractor.ExtractQueryFromFile(file, path)
			if err != nil {
				return fmt.Errorf("❌ %s: error while extracting: %w", relativePath, err)
			}
			fmt.Printf("✅ %s: %d queries extracted\n", relativePath, len(queries))
			extractedQueries = append(extractedQueries, queries...)
		}

		if err := extractor.WriteQueriesToFile(out, extractedQueries); err != nil {
			return fmt.Errorf("error writing queries to file: %w", err)
		}

		fmt.Printf("queries written to %s\n", out)

		return nil
	},
}

// main executes the root command and exits with status 1 if it returns an
// error.
func main() {
	if err := rootCmd.Execute(); err != nil {
		os.Exit(1)
	}
}

// init registers the "out" flag (shorthand "o") on the root command; it names
// the destination file and defaults to "extracted.sql".
func init() {
	flags := rootCmd.Flags()
	flags.StringP("out", "o", "extracted.sql", "Destination file that extracted queries will be written to")
}
60 changes: 60 additions & 0 deletions extractor/extractor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package extractor

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
)

var (
	// sqlPattern matches, case-insensitively, any text containing one of the
	// words SELECT, INSERT, UPDATE or DELETE as a whole word.
	sqlPattern = regexp.MustCompile(`(?i)\b(SELECT|INSERT|UPDATE|DELETE)\b`)

	// replacePattern matches a run of one or more whitespace characters.
	replacePattern = regexp.MustCompile(`\s+`)
)

// ExtractedQuery describes one SQL-looking string literal found in a Go source
// file.
type ExtractedQuery struct {
// file is the path of the source file, relative to the root directory given
// to ExtractQueryFromFile.
file string
// pos is the line number of the literal within the file.
pos int
// content is the text of the literal with whitespace runs collapsed to
// single spaces.
content string
}

// ExtractQueryFromFile parses the Go source file at path and returns every
// string literal whose text matches sqlPattern, in source order.
//
// Each literal is unquoted (so escape sequences such as \n or \" are decoded),
// and every run of whitespace is collapsed to a single space before matching.
// The file field of each result is path made relative to root, and pos is the
// line on which the literal starts.
//
// An error is returned when the file cannot be parsed or when path cannot be
// expressed relative to root.
func ExtractQueryFromFile(path string, root string) ([]*ExtractedQuery, error) {
	fset := token.NewFileSet()
	node, err := parser.ParseFile(fset, path, nil, parser.AllErrors)
	if err != nil {
		return nil, fmt.Errorf("error parsing file: %w", err)
	}

	// The relative path is identical for every literal in the file, so resolve
	// it once and report a failure instead of silently dropping queries.
	relativePath, err := filepath.Rel(root, path)
	if err != nil {
		return nil, fmt.Errorf("error getting relative path: %w", err)
	}

	// Collect the results.
	var results []*ExtractedQuery
	ast.Inspect(node, func(n ast.Node) bool {
		if n == nil {
			return false
		}
		// Only string literals are of interest; keep descending otherwise.
		lit, ok := n.(*ast.BasicLit)
		if !ok || lit.Kind != token.STRING {
			return true
		}

		// Decode the literal properly rather than trimming quote characters,
		// which would leave escape sequences undecoded and could eat quotes
		// that belong to the content.
		value, err := strconv.Unquote(lit.Value)
		if err != nil {
			// Not expected for a file that parsed successfully; fall back to
			// stripping the surrounding quote characters.
			value = strings.Trim(lit.Value, "\"`")
		}
		value = replacePattern.ReplaceAllString(value, " ")

		// Keep only strings that look like SQL queries.
		if sqlPattern.MatchString(value) {
			results = append(results, &ExtractedQuery{
				file:    relativePath,
				pos:     fset.Position(lit.Pos()).Line,
				content: value,
			})
		}
		return false
	})

	return results, nil
}
47 changes: 47 additions & 0 deletions extractor/extractor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package extractor

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestExtractQueryFromFile runs ExtractQueryFromFile against fixture files and
// compares the extracted queries, including their line numbers, with the
// expected list.
func TestExtractQueryFromFile(t *testing.T) {
	// q builds the expected entry for a query located in extractor1.go.
	q := func(line int, sql string) *ExtractedQuery {
		return &ExtractedQuery{file: "extractor1.go", pos: line, content: sql}
	}

	cases := []struct {
		name string
		path string
		root string
		want []*ExtractedQuery
	}{
		{
			name: "extract query from file",
			path: "testdata/extractor1.go",
			root: "testdata",
			want: []*ExtractedQuery{
				q(32, "SELECT id, name FROM users"),
				q(44, "INSERT INTO users (name) VALUES (?)"),
				q(59, "SELECT id, name FROM users WHERE id = ?"),
				q(72, "UPDATE users SET name = ? WHERE id = ?"),
				q(79, "DELETE FROM users WHERE id = ?"),
				q(89, "SELECT id, user_id, title, body FROM posts"),
				q(103, "INSERT INTO posts (user_id, title, body) VALUES (?, ?, ?)"),
				q(118, "SELECT id, user_id, title, body FROM posts WHERE id = ?"),
				q(133, "UPDATE posts SET user_id = ?, title = ?, body = ? WHERE id = ?"),
				q(140, "DELETE FROM posts WHERE id = ?"),
				q(150, "SELECT id, post_id, body FROM comments"),
				q(163, "INSERT INTO comments (post_id, body) VALUES (?, ?)"),
				q(178, "SELECT id, post_id, body FROM comments WHERE id = ?"),
				q(191, "UPDATE comments SET body = ? WHERE id = ?"),
				q(198, "DELETE FROM comments WHERE id = ?"),
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := ExtractQueryFromFile(tc.path, tc.root)
			assert.NoError(t, err)
			assert.Equal(t, tc.want, got)
		})
	}
}
58 changes: 58 additions & 0 deletions extractor/io.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package extractor

import (
"fmt"
"os"
"path/filepath"
"strings"
)

// IsValidDir reports whether dir can be stat'ed and refers to a directory.
// Any os.Stat failure is reported as false.
func IsValidDir(dir string) bool {
	if info, err := os.Stat(dir); err == nil {
		return info.IsDir()
	}
	return false
}

// ListAllGoFiles walks dir recursively and returns the paths of all regular
// entries whose name ends in ".go" but not in "_test.go", in the order
// filepath.Walk visits them.
//
// The walk stops at the first error it encounters. In that case the returned
// slice is nil and the error wraps the underlying cause, so callers can inspect
// it with errors.Is / errors.As.
func ListAllGoFiles(dir string) ([]string, error) {
	var files []string
	err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.IsDir() {
			return nil
		}
		if strings.HasSuffix(path, ".go") && !strings.HasSuffix(path, "_test.go") {
			files = append(files, path)
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("error walking directory: %w", err)
	}
	return files, nil
}

// WriteQueriesToFile creates (or truncates) the file at out and writes each
// query's String() form followed by a newline, in slice order.
//
// It returns an error if the file cannot be created, if any write fails, or if
// closing the file fails. The close error is checked because a failed close on
// a freshly written file can mean the data did not reach disk.
func WriteQueriesToFile(out string, queries []*ExtractedQuery) (err error) {
	f, createErr := os.Create(out)
	if createErr != nil {
		return fmt.Errorf("error creating file: %w", createErr)
	}
	defer func() {
		// Do not let a close failure mask an earlier write error.
		if closeErr := f.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("error closing file: %w", closeErr)
		}
	}()

	for _, query := range queries {
		if _, writeErr := f.WriteString(query.String() + "\n"); writeErr != nil {
			return fmt.Errorf("error writing to file: %w", writeErr)
		}
	}

	return nil
}

// String renders the query as a SQL comment line "-- <file>:<line>" followed,
// on the next line, by the query text.
func (q *ExtractedQuery) String() string {
	const layout = "-- %s:%d\n%s"
	return fmt.Sprintf(layout, q.file, q.pos, q.content)
}
Loading

0 comments on commit 80c4268

Please sign in to comment.