diff --git a/binaries/binaries.go b/binaries/binaries.go index cf7c81df0..e0902b528 100644 --- a/binaries/binaries.go +++ b/binaries/binaries.go @@ -15,11 +15,13 @@ import ( ) // PrismaVersion is a hardcoded version of the Prisma CLI. -const PrismaVersion = "3.13.0" +//const PrismaVersion = "3.13.0" +var PrismaVersion = "3.13.0" // EngineVersion is a hardcoded version of the Prisma Engine. // The versions can be found under https://github.com/prisma/prisma-engines/commits/master -const EngineVersion = "efdf9b1183dddfd4258cd181a72125755215ab7b" +//const EngineVersion = "694eea289a8462c80264df36757e4fdc129b1b32" +var EngineVersion = "694eea289a8462c80264df36757e4fdc129b1b32" // PrismaURL points to an S3 bucket URL where the CLI binaries are stored. var PrismaURL = "https://prisma-photongo.s3-eu-west-1.amazonaws.com/%s-%s-%s-x64.gz" @@ -63,7 +65,7 @@ func PrismaCLIName() string { return fmt.Sprintf("prisma-cli-%s-x64", variation) } -var baseDirName = path.Join("prisma", "binaries") +var baseDirName = filepath.Join("prisma", "binaries") // GlobalTempDir returns the path of where the engines live // internally, this is the global temp dir @@ -71,11 +73,11 @@ func GlobalTempDir(version string) string { temp := os.TempDir() logger.Debug.Printf("temp dir: %s", temp) - return path.Join(temp, baseDirName, "engines", version) + return filepath.ToSlash(filepath.Join(temp, baseDirName, "engines", version)) } func GlobalUnpackDir(version string) string { - return path.Join(GlobalTempDir(version), "unpacked", "v2") + return filepath.ToSlash(filepath.Join(GlobalTempDir(version), "unpacked", "v2")) } // GlobalCacheDir returns the path of where the CLI lives @@ -88,13 +90,13 @@ func GlobalCacheDir() string { logger.Debug.Printf("global cache dir: %s", cache) - return path.Join(cache, baseDirName, "cli", PrismaVersion) + return filepath.ToSlash(filepath.Join(cache, baseDirName, "cli", PrismaVersion)) } func FetchEngine(toDir string, engineName string, binaryPlatformName string) error { logger.Debug.Printf("checking %s...", engineName) - to := platform.CheckForExtension(binaryPlatformName, path.Join(toDir, EngineVersion, fmt.Sprintf("prisma-%s-%s", engineName, binaryPlatformName))) + to := platform.CheckForExtension(binaryPlatformName, filepath.ToSlash(filepath.Join(toDir, EngineVersion, fmt.Sprintf("prisma-%s-%s", engineName, binaryPlatformName)))) binaryPlatformRemoteName := binaryPlatformName if binaryPlatformRemoteName == "linux" { @@ -145,7 +147,7 @@ func FetchNative(toDir string) error { func DownloadCLI(toDir string) error { cli := PrismaCLIName() - to := platform.CheckForExtension(platform.Name(), path.Join(toDir, cli)) + to := platform.CheckForExtension(platform.Name(), filepath.ToSlash(filepath.Join(toDir, cli))) url := platform.CheckForExtension(platform.Name(), fmt.Sprintf(PrismaURL, "prisma-cli", PrismaVersion, platform.Name())) logger.Debug.Printf("ensuring CLI %s from %s to %s", cli, to, url) @@ -166,7 +168,7 @@ func DownloadCLI(toDir string) error { } func GetEnginePath(dir, engine, binaryName string) string { - return platform.CheckForExtension(binaryName, path.Join(dir, EngineVersion, fmt.Sprintf("prisma-%s-%s", engine, binaryName))) + return platform.CheckForExtension(binaryName, filepath.ToSlash(filepath.Join(dir, EngineVersion, fmt.Sprintf("prisma-%s-%s", engine, binaryName)))) } func DownloadEngine(name string, toDir string) (file string, err error) { @@ -174,7 +176,7 @@ func DownloadEngine(name string, toDir string) (file string, err error) { logger.Debug.Printf("checking %s...", name) - to := platform.CheckForExtension(binaryName, path.Join(toDir, EngineVersion, fmt.Sprintf("prisma-%s-%s", name, binaryName))) + to := platform.CheckForExtension(binaryName, filepath.ToSlash(filepath.Join(toDir, EngineVersion, fmt.Sprintf("prisma-%s-%s", name, binaryName)))) url := platform.CheckForExtension(binaryName, fmt.Sprintf(EngineURL, EngineVersion, binaryName, name)) diff --git a/binaries/platform/platform.go b/binaries/platform/platform.go index 38bc24731..c99ada772 100644 --- a/binaries/platform/platform.go +++ b/binaries/platform/platform.go @@ -19,8 +19,16 @@ func BinaryPlatformName() string { } platform := Name() + // Refer to https://github.dev/prisma/prisma/tree/main/packages/cli/src/utils + // Not well test for !win & !mac + arch := Arch() if platform != "linux" { + if platform == "darwin" { + if arch == "arm64" { + return fmt.Sprintf("%s-%s", platform, arch) + } + } return platform } @@ -33,6 +41,11 @@ func BinaryPlatformName() string { ssl := getOpenSSL() name := fmt.Sprintf("%s-openssl-%s", distro, ssl) + if arch == "arm64" { + name = fmt.Sprintf("%s--arm64-openssl-%s", distro, ssl) + } else if arch == "arm" { + name = fmt.Sprintf("%s-arm-openssl-%s", distro, ssl) + } binaryNameWithSSLCache = name @@ -44,6 +57,10 @@ func Name() string { return runtime.GOOS } +func Arch() string { + return runtime.GOARCH +} + // CheckForExtension adds a .exe extension on windows (e.g. .gz -> .exe.gz) func CheckForExtension(platform, path string) string { if platform == "windows" { diff --git a/binaries/unpack/unpack.go b/binaries/unpack/unpack.go index 0c34017a8..bbd64cef4 100644 --- a/binaries/unpack/unpack.go +++ b/binaries/unpack/unpack.go @@ -3,7 +3,7 @@ package unpack import ( "fmt" "os" - "path" + "path/filepath" "time" "github.com/prisma/prisma-client-go/binaries" @@ -24,7 +24,7 @@ func Unpack(data []byte, name string, version string) { tempDir := binaries.GlobalUnpackDir(version) - dir := platform.CheckForExtension(platform.Name(), path.Join(tempDir, file)) + dir := platform.CheckForExtension(platform.Name(), filepath.ToSlash(filepath.ToSlash(filepath.Join(tempDir, file)))) if err := os.MkdirAll(tempDir, os.ModePerm); err != nil { panic(fmt.Errorf("mkdirall failed: %w", err)) diff --git a/cli/cli.go b/cli/cli.go index 89c31299d..51fb69086 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -2,13 +2,12 @@ package cli import ( "fmt" - "os" - "os/exec" - "path" - "github.com/prisma/prisma-client-go/binaries" "github.com/prisma/prisma-client-go/binaries/platform" "github.com/prisma/prisma-client-go/logger" + "os" + "os/exec" + "path/filepath" ) // Run the prisma CLI with given arguments @@ -25,9 +24,9 @@ func Run(arguments []string, output bool) error { prisma := binaries.PrismaCLIName() - logger.Debug.Printf("running %s %+v", path.Join(dir, prisma), arguments) + logger.Debug.Printf("running %s %+v", filepath.ToSlash(filepath.Join(dir, prisma)), arguments) - cmd := exec.Command(path.Join(dir, prisma), arguments...) //nolint:gosec + cmd := exec.Command(filepath.ToSlash(filepath.Join(dir, prisma)), arguments...) //nolint:gosec binaryName := platform.CheckForExtension(platform.Name(), platform.BinaryPlatformName()) cmd.Env = os.Environ() @@ -41,7 +40,7 @@ func Run(arguments []string, output bool) error { logger.Debug.Printf("overriding %s to %s", engine.Name, env) value = env } else { - value = path.Join(dir, binaries.EngineVersion, fmt.Sprintf("prisma-%s-%s", engine.Name, binaryName)) + value = filepath.ToSlash(filepath.Join(dir, binaries.EngineVersion, fmt.Sprintf("prisma-%s-%s", engine.Name, binaryName))) } cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", engine.Env, value)) diff --git a/engine/api.go b/engine/api.go new file mode 100644 index 000000000..c6e2399b6 --- /dev/null +++ b/engine/api.go @@ -0,0 +1,232 @@ +package engine + +import ( + "context" + "fmt" + "github.com/joho/godotenv" + "github.com/prisma/prisma-client-go/binaries" + "github.com/prisma/prisma-client-go/binaries/platform" + "github.com/prisma/prisma-client-go/engine/introspection" + "github.com/prisma/prisma-client-go/engine/migrate" + "github.com/prisma/prisma-client-go/generator/ast/dmmf" + "github.com/prisma/prisma-client-go/logger" + "os" + "os/exec" + "path/filepath" + "strings" + "time" +) + +func NewDMFQueryEngine(schema string) (*QueryEngine, error) { + content, err := Pull(schema) + if err != nil { + return nil, err + } + queryEngine := NewQueryEngine(content, false) + if err := queryEngine.ConnectSDK(); err != nil { + logger.Debug.Printf("connect fail err : ", err) + return nil, err + } + return queryEngine, nil +} + +var globalQueryEngine *QueryEngine + +func GetQueryEngineOnce(schema string) *QueryEngine { + if globalQueryEngine == nil { + content, err := Pull(schema) + if err != nil { + logger.Debug.Printf("connect fail err : ", err) + } + globalQueryEngine = NewQueryEngine(content, false) + if err := globalQueryEngine.ConnectSDK(); err != nil { + logger.Debug.Printf("connect fail err : ", err) + } + } + return globalQueryEngine +} + +func ReloadQueryEngineOnce(schema string) (*QueryEngine, error) { + // 先释放掉老的资源 + if globalQueryEngine != nil { + globalQueryEngine.Disconnect() + globalQueryEngine = nil + } + // 内省 + content, err := Pull(schema) + if err != nil { + logger.Debug.Printf("connect fail err : ", err) + return nil, err + } + + globalQueryEngine = NewQueryEngine(content, false) + if err := globalQueryEngine.ConnectSDK(); err != nil { + logger.Debug.Printf("connect fail err : ", err) + } + return globalQueryEngine, nil +} + +func Push(schemaPath string) error { + migrationEngine := migrate.NewMigrationEngine() + return migrationEngine.Push(schemaPath) +} + +func Pull(schema string) (string, error) { + migrationEngine := introspection.NewIntrospectEngine() + return migrationEngine.Pull(schema) +} + +func InitQueryEngine(schema string) error { + content, err := Pull(schema) + if err != nil { + logger.Debug.Printf("connect fail err : ", err) + return err + } + globalQueryEngine = NewQueryEngine(content, false) + if err := globalQueryEngine.ConnectSDK(); err != nil { + logger.Debug.Printf("connect fail err : ", err) + return err + } + return nil +} + +func QuerySchema(querySchema string, result interface{}) error { + ctx := context.TODO() + payload := GQLRequest{ + Query: querySchema, + Variables: map[string]interface{}{}, + } + err := globalQueryEngine.Do(ctx, payload, result) + if err != nil { + return err + } + return nil +} + +func QuerySDL(dbSchemaPath, sdlSchema string, result interface{}) error { + queryEngine := GetQueryEngineOnce(dbSchemaPath) + ctx := context.TODO() + payload := GQLRequest{ + Query: sdlSchema, + Variables: map[string]interface{}{}, + } + err := queryEngine.Do(ctx, payload, result) + if err != nil { + return err + } + return nil +} + +func QueryDMMF(dbSchemaPath string) (*dmmf.Document, error) { + queryEngine := GetQueryEngineOnce(dbSchemaPath) + return queryEngine.IntrospectDMMF(context.TODO()) +} + +func (e *QueryEngine) ensureSDK() (string, error) { + ensureEngine := time.Now() + + dir := binaries.GlobalCacheDir() + // 确保引擎一定下载了 + if err := binaries.FetchNative(dir); err != nil { + return "", fmt.Errorf("could not fetch binaries: %w", err) + } + binariesPath := filepath.ToSlash(filepath.Join(dir, binaries.EngineVersion)) + //binaryName := platform.CheckForExtension(platform.Name(), platform.Name()) + binaryName := platform.BinaryPlatformName() + + exactBinaryName := platform.CheckForExtension(platform.Name(), platform.BinaryPlatformName()) + + var file string + forceVersion := true + + name := "prisma-query-engine-" + localPath := filepath.ToSlash(filepath.Join("./", name+binaryName)) + localExactPath := filepath.ToSlash(filepath.Join("./", name+exactBinaryName)) + globalPath := filepath.ToSlash(filepath.Join(binariesPath, name+binaryName)) + globalExactPath := filepath.ToSlash(filepath.Join(binariesPath, name+exactBinaryName)) + + logger.Debug.Printf("expecting local query engine `%s` or `%s`", localPath, localExactPath) + logger.Debug.Printf("expecting global query engine `%s` or `%s`", globalPath, globalExactPath) + + // TODO write tests for all cases + + // first, check if the query engine binary is being overridden by PRISMA_QUERY_ENGINE_BINARY + prismaQueryEngineBinary := os.Getenv("PRISMA_QUERY_ENGINE_BINARY") + if prismaQueryEngineBinary != "" { + logger.Debug.Printf("PRISMA_QUERY_ENGINE_BINARY is defined, using %s", prismaQueryEngineBinary) + + if _, err := os.Stat(prismaQueryEngineBinary); err != nil { + return "", fmt.Errorf("PRISMA_QUERY_ENGINE_BINARY was provided, but no query engine was found at %s", prismaQueryEngineBinary) + } + + file = prismaQueryEngineBinary + forceVersion = false + } else { + if _, err := os.Stat(localExactPath); err == nil { + logger.Debug.Printf("exact query engine found in working directory") + file = localExactPath + } else if _, err := os.Stat(localPath); err == nil { + logger.Debug.Printf("query engine found in working directory") + file = localPath + } + + if _, err := os.Stat(globalExactPath); err == nil { + logger.Debug.Printf("query engine found in global path") + file = globalExactPath + } else if _, err := os.Stat(globalPath); err == nil { + logger.Debug.Printf("exact query engine found in global path") + file = globalPath + } + } + + if file == "" { + // TODO log instructions on how to fix this problem + return "", fmt.Errorf("no binary found ") + } + + startVersion := time.Now() + out, err := exec.Command(file, "--version").Output() + if err != nil { + return "", fmt.Errorf("version check failed: %w", err) + } + logger.Debug.Printf("version check took %s", time.Since(startVersion)) + + if v := strings.TrimSpace(strings.Replace(string(out), "query-engine", "", 1)); binaries.EngineVersion != v { + note := "Did you forget to run `go run github.com/prisma/prisma-client-go generate`?" + msg := fmt.Errorf("expected query engine version `%s` but got `%s`\n%s", binaries.EngineVersion, v, note) + if forceVersion { + return "", msg + } + + logger.Info.Printf("%s, ignoring since custom query engine was provided", msg) + } + + logger.Debug.Printf("using query engine at %s", file) + logger.Debug.Printf("ensure query engine took %s", time.Since(ensureEngine)) + + return file, nil +} + +func (e *QueryEngine) ConnectSDK() error { + logger.Debug.Printf("ensure query engine binary...") + + _ = godotenv.Load("e2e.env") + _ = godotenv.Load("db/e2e.env") + _ = godotenv.Load("prisma/e2e.env") + + startEngine := time.Now() + + file, err := e.ensureSDK() + if err != nil { + return fmt.Errorf("ensure: %w", err) + } + + if err := e.spawn(file); err != nil { + return fmt.Errorf("spawn: %w", err) + } + + logger.Debug.Printf("connecting took %s", time.Since(startEngine)) + logger.Debug.Printf("connected.") + + return nil +} diff --git a/engine/engine_factory.go b/engine/engine_factory.go new file mode 100644 index 000000000..ff62b4a30 --- /dev/null +++ b/engine/engine_factory.go @@ -0,0 +1,277 @@ +package engine + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/prisma/prisma-client-go/logger" + "github.com/vektah/gqlparser/v2/ast" + "github.com/vektah/gqlparser/v2/formatter" + "github.com/vektah/gqlparser/v2/parser" + "strconv" + "time" +) + +var prismaQueryEngineMap = map[int64]*QueryEngine{} + +type EngineFactory interface { + GetPrismaQueryEngine() (*QueryEngine, error) + ReloadPrismaQueryEngine() error + QuerySchema(param GQLRequest, result interface{}) error +} + +// QueryEngineFactory 创建工厂结构体并实现工厂接口 +type QueryEngineFactory struct { + Key int64 `json:"key"` + DBSchema string `json:"DBSchema"` +} + +func NewQueryEngineFactory(key int64, dbSchema string) QueryEngineFactory { + return QueryEngineFactory{ + Key: key, + DBSchema: dbSchema, + } +} + +func (q *QueryEngineFactory) GetPrismaQueryEngine() (*QueryEngine, error) { + // 如果不存在 + if _, ok := prismaQueryEngineMap[q.Key]; !ok { + // 创建 + content, err := Pull(q.DBSchema) + if err != nil { + logger.Debug.Printf("connect fail err : ", err) + return nil, err + } + queryEngine := NewQueryEngine(content, false) + if err := queryEngine.ConnectSDK(); err != nil { + logger.Debug.Printf("connect fail err : ", err) + return nil, err + } + prismaQueryEngineMap[q.Key] = queryEngine + } + return prismaQueryEngineMap[q.Key], nil +} + +func (q *QueryEngineFactory) ReloadPrismaQueryEngine() error { + // 先销毁旧的引擎 + if _, ok := prismaQueryEngineMap[q.Key]; ok { + prismaQueryEngineMap[q.Key].Disconnect() + } + + // 创建新引擎 + content, err := Pull(q.DBSchema) + if err != nil { + logger.Debug.Printf("connect fail err : ", err) + return err + } + queryEngine := NewQueryEngine(content, false) + if err := queryEngine.ConnectSDK(); err != nil { + logger.Debug.Printf("connect fail err : ", err) + return err + } + // 存入 + prismaQueryEngineMap[q.Key] = queryEngine + return nil +} + +func DisConnectEngine() { + for key, engine := range prismaQueryEngineMap { + engine.Disconnect() + delete(prismaQueryEngineMap, key) + } +} + +func (q *QueryEngineFactory) QuerySchema(param GQLRequest, result interface{}) error { + defer func() { + if err := recover(); err != nil { + logger.Debug.Println(err) + } + }() + ctx := context.TODO() + queryEngine, err := q.GetPrismaQueryEngine() + if err != nil { + return err + } + //err = queryEngine.DoQuery(ctx, param, result) + err = queryEngine.DoManyQuery(ctx, param, result) + if err != nil { + return err + } + return nil +} + +type GQLResult struct { + Data json.RawMessage `json:"data"` + Errors []GQLError `json:"errors"` + Extensions map[string]interface{} `json:"extensions"` +} + +type ErrResponse struct { + Errors []SQLErrResult `json:"errors"` +} + +type SQLErrResult struct { + Message string `json:"message"` + Path []string `json:"path"` + Locations []interface{} `json:"locations"` +} + +func NewSqlErrResult(errStr string) SQLErrResult { + return SQLErrResult{ + Message: errStr, + Path: make([]string, 0), + Locations: make([]interface{}, 0), + } +} +func (e *QueryEngine) DoManyQuery(ctx context.Context, payload GQLRequest, v interface{}) error { + queryObj, _ := parser.ParseQuery(&ast.Source{Input: payload.Query}) + resultErr := ErrResponse{ + Errors: make([]SQLErrResult, 0), + } + if len(queryObj.Operations) != 1 { + return fmt.Errorf("一次只能查询一个operation") + } + ope := queryObj.Operations[0] + + // 如果只有一个查询,则直接查询返回 + if len(ope.SelectionSet) == 1 { + onePayLoad := GQLRequest{ + Query: payload.Query, + Variables: map[string]interface{}{}, + } + return e.DoQuery(ctx, onePayLoad, v) + } + + // 多个查询 + requests := make([]GQLRequest, len(ope.SelectionSet)) + + selectionset := ope.SelectionSet + for i, selection := range selectionset { + ope.SelectionSet = ast.SelectionSet{selection} + requests[i] = GQLRequest{ + Query: FormatOperateionDocument(ope), + Variables: map[string]interface{}{}, + } + } + type GQLBatchResult struct { + Errors []GQLError `json:"errors"` + Result []GQLResult `json:"batchResult"` + } + var result GQLBatchResult + payloads := GQLBatchRequest{ + Batch: requests, + Transaction: true, + } + + if err := e.BatchReq(ctx, payloads, &result); err != nil { + // 如果出现错误,则将错误返回给前端 + resultErr.Errors = append(resultErr.Errors, NewSqlErrResult(err.Error())) + errBytes, _ := json.Marshal(resultErr) + if err := json.Unmarshal(errBytes, v); err != nil { + return fmt.Errorf("json unmarshal: %w", err) + } + return nil + } + if len(result.Errors) > 0 { + // 如果出现错误,则将错误返回给前端 + resultErr.Errors = append(resultErr.Errors, NewSqlErrResult(result.Errors[0].RawMessage())) + errBytes, _ := json.Marshal(resultErr) + if err := json.Unmarshal(errBytes, v); err != nil { + return fmt.Errorf("json unmarshal: %w", err) + } + return nil + } + // 合并JSON字符串 + var tmpRes string + for idx, inner := range result.Result { + if len(inner.Errors) > 0 { + // 如果出现错误,则直接将错误返回给前端 + resultErr.Errors = append(resultErr.Errors, NewSqlErrResult(result.Errors[0].RawMessage())) + errBytes, _ := json.Marshal(resultErr) + if err := json.Unmarshal(errBytes, v); err != nil { + return fmt.Errorf("json unmarshal: %w", err) + } + return nil + } + + str := string(inner.Data) + // 最后一条 + if idx == len(result.Result)-1 { + tmpRes = tmpRes + str[1:] // 删除开头的{ + } else { + // 非最后一条 + tmpRes = tmpRes + str[:len(str)-1] + "," // 删除结尾的} + } + } + resultStruct := struct { + Data interface{} `json:"data"` + }{} + if err := json.Unmarshal([]byte(tmpRes), &resultStruct.Data); err != nil { + return fmt.Errorf("json unmarshal: %w", err) + } + + resultBytes, _ := json.Marshal(resultStruct) + if err := json.Unmarshal(resultBytes, v); err != nil { + return fmt.Errorf("json unmarshal: %w", err) + } + return nil +} + +func (e *QueryEngine) DoQuery(ctx context.Context, payload GQLRequest, v interface{}) error { + startReq := time.Now() + + body, err := e.Request(ctx, "POST", "/", payload) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + + logger.Debug.Printf("[timing] query engine request took %s", time.Since(startReq)) + + if err := json.Unmarshal(body, v); err != nil { + return fmt.Errorf("json unmarshal: %w", err) + } + return nil +} + +func InterfaceToString(i interface{}) string { + switch v := i.(type) { + case string: + return v + case int: + return strconv.Itoa(v) + case bool: + return strconv.FormatBool(v) + default: + bytes, _ := json.Marshal(v) + return string(bytes) + } +} + +// 位置挪走 +func FormatOperateionDocument(operate *ast.OperationDefinition) string { + + query := &ast.QueryDocument{ + Operations: ast.OperationList{operate}, + } + var buf bytes.Buffer + formatter.NewFormatter(&buf).FormatQueryDocument(query) + + bufstr := buf.String() + + return bufstr +} + +// Do sends the http Request to the query engine and unmarshals the response +func (e *QueryEngine) BatchReq(ctx context.Context, payload interface{}, v interface{}) error { + body, err := e.Request(ctx, "POST", "/", payload) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + + if err := json.Unmarshal(body, &v); err != nil { + return fmt.Errorf("json unmarshal: %w", err) + } + + return nil +} diff --git a/engine/introspection/me.go b/engine/introspection/me.go new file mode 100644 index 000000000..b443297e6 --- /dev/null +++ b/engine/introspection/me.go @@ -0,0 +1,269 @@ +package introspection + +import ( + "fmt" + "github.com/prisma/prisma-client-go/binaries" + "github.com/prisma/prisma-client-go/binaries/platform" + "github.com/prisma/prisma-client-go/logger" + "io" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + "time" +) + +import ( + "bufio" + "context" + "encoding/json" +) + +func InitEngine() { + +} + +func NewIntrospectEngine() *IntrospectEngine { + // TODO:这里可以设置默认值 + engine := &IntrospectEngine{ + // path: path, + } + file, err := engine.ensure() //确保引擎一定安装了 + if err != nil { + panic(err) + } + engine.path = file + return engine +} + +type IntrospectEngine struct { + path string +} + +func (e *IntrospectEngine) ensure() (string, error) { + ensureEngine := time.Now() + + dir := binaries.GlobalCacheDir() + // 确保引擎一定下载了 + if err := binaries.FetchNative(dir); err != nil { + return "", fmt.Errorf("could not fetch binaries: %w", err) + } + // check for darwin/windows/linux first + //binaryName := platform.CheckForExtension(platform.Name(), platform.Name()) + binaryName := platform.BinaryPlatformName() + if platform.Name() == "windows" { + binaryName = fmt.Sprintf("%s.exe", binaryName) + } + var file string + // forceVersion saves whether a version check should be done, which should be disabled + // when providing a custom query engine value + // forceVersion := true + name := "prisma-introspection-engine-" + globalPath := filepath.ToSlash(filepath.Join(dir, binaries.EngineVersion, name+binaryName)) + + logger.Debug.Printf("expecting global introspection engine `%s` ", globalPath) + + // TODO write tests for all cases + // first, check if the query engine binary is being overridden by PRISMA_MIGRATION_ENGINE_BINARY + prismaQueryEngineBinary := os.Getenv("PRISMA_INTROSPECTION_ENGINE_BINARY") + if prismaQueryEngineBinary != "" { + logger.Debug.Printf("PRISMA_INTROSPECTION_ENGINE_BINARY is defined, using %s", prismaQueryEngineBinary) + + if _, err := os.Stat(prismaQueryEngineBinary); err != nil { + return "", fmt.Errorf("PRISMA_INTROSPECTION_ENGINE_BINARY was provided, but no query engine was found at %s", prismaQueryEngineBinary) + } + + file = prismaQueryEngineBinary + // forceVersion = false + } else { + if _, err := os.Stat(globalPath); err == nil { + logger.Debug.Printf("exact query engine found in global path") + file = globalPath + } + } + + if file == "" { + // TODO log instructions on how to fix this problem + return "", fmt.Errorf("no binary found ") + } + logger.Debug.Printf("using introspection engine at %s", file) + logger.Debug.Printf("ensure introspection engine took %s", time.Since(ensureEngine)) + + return file, nil +} + +//func (e *IntrospectEngine) Pull(schema string) (string, error) { +// startParse := time.Now() +// +// ctx, cancel := context.WithTimeout(context.Background(), time.Second*600) +// defer cancel() +// +// cmd := exec.CommandContext(ctx, e.path) +// +// pipe, err := cmd.StdinPipe() // 标准输入流 +// if err != nil { +// return "", fmt.Errorf("introspect engine std in pipe %v", err.Error()) +// } +// defer pipe.Close() +// // 构建一个json-rpc 请求参数 +// req := IntrospectRequest{ +// Id: 1, +// Jsonrpc: "2.0", +// Method: "introspect", +// Params: []map[string]interface{}{ +// { +// "schema": string(schema), +// "compositeTypeDepth": -1, +// }, +// }, +// } +// +// data, err := json.Marshal(req) +// if err != nil { +// return "", err +// } +// // 入参追加到管道中 +// _, err = pipe.Write(append(data, []byte("\n")...)) +// if err != nil { +// // return "", err +// return "", err +// } +// +// out, err := cmd.StdoutPipe() +// if err != nil { +// err = fmt.Errorf("Introspect std out pipe %s ", err.Error()) +// } +// r := bufio.NewReader(out) +// +// // 开始执行 +// err = cmd.Start() +// if err != nil { +// return "", err +// } +// +// //var response IntrospectResponse +// var response IntrospectResponse +// +// outBuf := &bytes.Buffer{} +// // 这一段的意思是,每100ms读取一次结果,直到超时或有结果 +// for { +// // 等待100 ms +// //time.Sleep(time.Millisecond * 100) +// b, err := r.ReadByte() +// if err != nil { +// err = fmt.Errorf("Introspect ReadByte %s ", err.Error()) +// } +// err = outBuf.WriteByte(b) +// if err != nil { +// err = fmt.Errorf("IntrospectwriteByte %s ", err.Error()) +// } +// +// if b == '\n' { +// // 解析响应结果 +// err = json.Unmarshal(outBuf.Bytes(), &response) +// if err != nil { +// return "", err +// } +// if response.Error == nil { +// log.Println("introspect successful") +// } +// fmt.Print("ende ") +// break +// } +// // 如果超时了?跳出读取? +// if err := ctx.Err(); err != nil { +// return "", err +// } +// } +// log.Printf("[timing] introspect took %s", time.Since(startParse)) +// if response.Error != nil { +// return "", fmt.Errorf("introspect error: %s", response.Error.Data.Message) +// } +// dataModel := strings.Replace(response.Result.DataModel, " Bytes", " String", -1) +// //dataModel := strings.Replace(response.Result.DataModel, " Bytes", " String", -1) +// return dataModel, nil +//} + +func (e *IntrospectEngine) Pull(schema string) (string, error) { + startParse := time.Now() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) + defer cancel() // 读取一行数据后,发送kill信号 + + cmd := exec.CommandContext(ctx, e.path) + + pipe, err := cmd.StdinPipe() // 标准输入流 + if err != nil { + return "", fmt.Errorf("introspect engine std in pipe %v", err.Error()) + } + defer pipe.Close() + // 构建一个json-rpc 请求参数 + req := IntrospectRequest{ + Id: 1, + Jsonrpc: "2.0", + Method: "introspect", + Params: []map[string]interface{}{ + { + "schema": string(schema), + "compositeTypeDepth": -1, + }, + }, + } + + data, err := json.Marshal(req) + if err != nil { + return "", err + } + // 入参追加到管道中 + _, err = pipe.Write(append(data, []byte("\n")...)) + if err != nil { + return "", err + } + stdout, err := cmd.StdoutPipe() + + if err != nil { + log.Println(err) + return "", err + } + + // 不阻塞启动 + if err := cmd.Start(); err != nil { + return "", err + } + + // 使用cmd.wait关闭子进程 + go func() { + if err := cmd.Wait(); err != nil { + fmt.Sprintf("Child proess %d exit with err :%v \n", cmd.Process.Pid, err) + } + }() + + reader := bufio.NewReader(stdout) + + // TODO:如果一直堵死在这咋办? + //阻塞读取,实时读取输出流中的一行内容 + line, err2 := reader.ReadString('\n') + if err2 != nil || io.EOF == err2 { + return "", err2 + } + log.Println(line) + + var response IntrospectResponse + + // 解析响应结果 + err = json.Unmarshal([]byte(line), &response) + if err != nil { + return "", err + } + + log.Printf("[timing] introspect took %s", time.Since(startParse)) + if response.Error != nil { + return "", fmt.Errorf("introspect error: %s", response.Error.Data.Message) + } + log.Println("introspect successful") + + dataModel := strings.Replace(response.Result.DataModel, " Bytes", " String", -1) + //dataModel := strings.Replace(response.Result.DataModel, " Bytes", " String", -1) + return dataModel, nil +} diff --git a/engine/introspection/struct.go b/engine/introspection/struct.go new file mode 100644 index 000000000..7ebd4ecb3 --- /dev/null +++ b/engine/introspection/struct.go @@ -0,0 +1,49 @@ +package introspection + +//type IntrospectRequest struct { +// Id int `json:"id"` +// Jsonrpc string `json:"jsonrpc"` +// Method string `json:"method"` +// Params IntrospectRequestParams `json:"params"` +//} + +type IntrospectRequest struct { + Id int `json:"id"` + Jsonrpc string `json:"jsonrpc"` + Method string `json:"method"` + Params []map[string]interface{} `json:"params"` +} + +type IntrospectRequestParams struct { + CompositeTypeDepth int64 `json:"compositeTypeDepth"` + Schema string `json:"schema"` +} + +type IntrospectResponse struct { + Jsonrpc string `json:"jsonrpc"` + Result *IntrospectResponseResult `json:"result,omitempty"` + Error *IntrospectResponseError `json:"error,omitempty"` +} + +type IntrospectResponseResult struct { + ExecutedSteps int `json:"executedSteps"` + DataModel string `json:"dataModel"` + Marnings interface{} `json:"marnings"` + Version string `json:"version"` +} + +type IntrospectResponseError struct { + Code int `json:"code"` + Message string `json:"message"` + Data IntrospectResponseErrorData `json:"data"` +} + +type IntrospectResponseErrorData struct { + IsPanic bool `json:"is_panic"` + Message string `json:"message"` + Meta IntrospectResponseErrorDataMeta `json:"meta"` +} + +type IntrospectResponseErrorDataMeta struct { + FullError string `json:"full_error"` +} diff --git a/engine/lifecycle.go b/engine/lifecycle.go index 8659b408b..30800f40a 100644 --- a/engine/lifecycle.go +++ b/engine/lifecycle.go @@ -6,7 +6,7 @@ import ( "fmt" "os" "os/exec" - "path" + "path/filepath" "strings" "time" @@ -80,10 +80,10 @@ func (e *QueryEngine) ensure() (string, error) { forceVersion := true name := "prisma-query-engine-" - localPath := path.Join("./", name+binaryName) - localExactPath := path.Join("./", name+exactBinaryName) - globalPath := path.Join(binariesPath, name+binaryName) - globalExactPath := path.Join(binariesPath, name+exactBinaryName) + localPath := filepath.ToSlash(filepath.Join("./", name+binaryName)) + localExactPath := filepath.ToSlash(filepath.Join("./", name+exactBinaryName)) + globalPath := filepath.ToSlash(filepath.Join(binariesPath, name+binaryName)) + globalExactPath := filepath.ToSlash(filepath.Join(binariesPath, name+exactBinaryName)) logger.Debug.Printf("expecting local query engine `%s` or `%s`", localPath, localExactPath) logger.Debug.Printf("expecting global query engine `%s` or `%s`", globalPath, globalExactPath) diff --git a/engine/migrate/me.go b/engine/migrate/me.go new file mode 100644 index 000000000..be60f99f9 --- /dev/null +++ b/engine/migrate/me.go @@ -0,0 +1,199 @@ +package migrate + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/prisma/prisma-client-go/binaries" + "github.com/prisma/prisma-client-go/binaries/platform" + "github.com/prisma/prisma-client-go/logger" +) + +func NewMigrationEngine() *MigrationEngine { + // TODO:这里可以设置默认值 + engine := &MigrationEngine{ + // path: path, + } + file, err := engine.ensure() //确保引擎一定安装了 + if err != nil { + panic(err) + } + engine.path = file + return engine +} + +type MigrationEngine struct { + path string +} + +// func (e *MigrationEngine) Name() string { +// return "migration-engine" +// } + +func (e *MigrationEngine) ensure() (string, error) { + ensureEngine := time.Now() + + dir := binaries.GlobalCacheDir() + // 确保引擎一定下载了 + if err := binaries.FetchNative(dir); err != nil { + return "", fmt.Errorf("could not fetch binaries: %w", err) + } + // check for darwin/windows/linux first + //binaryName := platform.CheckForExtension(platform.Name(), platform.Name()) + binaryName := platform.BinaryPlatformName() + if platform.Name() == "windows" { + binaryName = fmt.Sprintf("%s.exe", binaryName) + } + + var file string + // forceVersion saves whether a version check should be done, which should be disabled + // when providing a custom query engine value + // forceVersion := true + + name := "prisma-migration-engine-" + globalPath := filepath.ToSlash(filepath.Join(dir, binaries.EngineVersion, name+binaryName)) + + logger.Debug.Printf("expecting global migration engine `%s` ", globalPath) + + // TODO write tests for all cases + + // first, check if the query engine binary is being overridden by PRISMA_MIGRATION_ENGINE_BINARY + prismaQueryEngineBinary := os.Getenv("PRISMA_MIGRATION_ENGINE_BINARY") + if prismaQueryEngineBinary != "" { + logger.Debug.Printf("PRISMA_MIGRATION_ENGINE_BINARY is defined, using %s", prismaQueryEngineBinary) + + if _, err := os.Stat(prismaQueryEngineBinary); err != nil { + return "", fmt.Errorf("PRISMA_MIGRATION_ENGINE_BINARY was provided, but no query engine was found at %s", prismaQueryEngineBinary) + } + + file = prismaQueryEngineBinary + // forceVersion = false + } else { + if _, err := os.Stat(globalPath); err == nil { + logger.Debug.Printf("exact query engine found in global path") + file = globalPath + } + } + + if file == "" { + // TODO log instructions on how to fix this problem + return "", fmt.Errorf("no binary found ") + } + + logger.Debug.Printf("using migration engine at %s", file) + logger.Debug.Printf("ensure migration engine took %s", time.Since(ensureEngine)) + + return file, nil +} + +func (e *MigrationEngine) Push(schemaPath string) error { + startParse := time.Now() + // 可以缓存到改引擎中? + schema, err := ioutil.ReadFile(schemaPath) + if err != nil { + err = fmt.Errorf("load prisma schema: %s", err.Error()) + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*50) + defer cancel() + + cmd := exec.CommandContext(ctx, e.path, "--datamodel", schemaPath) + + pipe, err := cmd.StdinPipe() // 标准输入流 + if err != nil { + err = fmt.Errorf("migration engine std in pipe: %s", err.Error()) + return err + // return "", err + } + defer pipe.Close() + // 构建一个json-rpc 请求参数 + req := MigrationRequest{ + Id: 1, + Jsonrpc: "2.0", + Method: "schemaPush", + Params: MigrationRequestParams{ + Force: true, + Schema: string(schema), + }, + } + + data, err := json.Marshal(req) + if err != nil { + return err + } + // 入参追加到管道中 + _, err = pipe.Write(append(data, []byte("\n")...)) + if err != nil { + // return "", err + return err + } + + out, err := cmd.StdoutPipe() + if err != nil { + err = fmt.Errorf("migration std out pipe: %s", err.Error()) + } + r := bufio.NewReader(out) + + // 开始执行 + err = cmd.Start() + if err != nil { + return err + } + + // 使用cmd.wait关闭子进程 + go func() { + if err := cmd.Wait(); err != nil { + fmt.Sprintf("Child proess %d exit with err :%v \n", cmd.Process.Pid, err) + } + }() + + var response MigrationResponse + + outBuf := &bytes.Buffer{} + // {\"jsonrpc\":\"2.0\",\"result\":{\"executedSteps\":1,\"unexecutable\":[],\"warnings\":[]},\"id\":1}\n + // 这一段的意思是,每100ms读取一次结果,直到超时或有结果 + for { + // 等待100 ms + //time.Sleep(time.Millisecond * 100) + b, err := r.ReadByte() + if err != nil { + err = fmt.Errorf("migration ReadByte: %s", err.Error()) + } + err = outBuf.WriteByte(b) + if err != nil { + err = fmt.Errorf("migration writeByte: %s", err.Error()) + } + + if b == '\n' { + // 解析响应结果 + err = json.Unmarshal(outBuf.Bytes(), &response) + if err != nil { + return err + } + if response.Error == nil { + log.Println("Migration successful") + } + fmt.Print("ende ") + break + } + // 如果超时了?跳出读取? + if err := ctx.Err(); err != nil { + return err + } + } + log.Printf("[timing] migrate took %s", time.Since(startParse)) + if response.Error != nil { + return fmt.Errorf("migrate error: %s", response.Error.Data.Message) + } + return nil +} diff --git a/engine/migrate/struct.go b/engine/migrate/struct.go new file mode 100644 index 000000000..1a99ae525 --- /dev/null +++ b/engine/migrate/struct.go @@ -0,0 +1,39 @@ +package migrate + +type MigrationRequest struct { + Id int `json:"id"` + Jsonrpc string `json:"jsonrpc"` + Method string `json:"method"` + Params MigrationRequestParams `json:"params"` +} + +type MigrationRequestParams struct { + Force bool `json:"force"` + Schema string `json:"schema"` +} + +type MigrationResponse struct { + Jsonrpc string `json:"jsonrpc"` + Result *MigrationResponseResult `json:"result,omitempty"` + Error *MigrationResponseError `json:"error,omitempty"` +} + +type MigrationResponseResult struct { + ExecutedSteps int `json:"executedSteps"` +} + +type MigrationResponseError struct { + Code int `json:"code"` + Message string `json:"message"` + Data MigrationResponseErrorData `json:"data"` +} + +type MigrationResponseErrorData struct { + IsPanic bool `json:"is_panic"` + Message string `json:"message"` + Meta MigrationResponseErrorDataMeta `json:"meta"` +} + +type MigrationResponseErrorDataMeta struct { + FullError string `json:"full_error"` +} diff --git a/engine/proxy.go b/engine/proxy.go index 681ff9d63..7468f0764 100644 --- a/engine/proxy.go +++ b/engine/proxy.go @@ -9,7 +9,7 @@ import ( "fmt" "net/http" "net/url" - "path" + "path/filepath" "time" "github.com/prisma/prisma-client-go/binaries" @@ -188,5 +188,5 @@ func encodeSchema(schema string) string { } func getCloudURI(host, schemaHash string) string { - return "https://" + path.Join(host, binaries.PrismaVersion, schemaHash) + return "https://" + filepath.ToSlash(filepath.Join(host, binaries.PrismaVersion, schemaHash)) } diff --git a/engine/request.go b/engine/request.go index 6a033e5cf..ac57b7aa1 100644 --- a/engine/request.go +++ b/engine/request.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/prisma/prisma-client-go/generator/ast/dmmf" "github.com/prisma/prisma-client-go/logger" "github.com/prisma/prisma-client-go/runtime/types" ) @@ -84,3 +85,42 @@ func (e *QueryEngine) Request(ctx context.Context, method string, path string, p req.Header.Set("content-type", "application/json") }) } + +func (e *QueryEngine) IntrospectDMMF(ctx context.Context) (*dmmf.Document, error) { + startReq := time.Now() + body, err := e.Request(ctx, "GET", "/dmmf", nil) + if err != nil { + logger.Info.Printf("dmmf request failed: %s", err) + return nil, err + } + + logger.Debug.Printf("[timing] query engine dmmf request took %s", time.Since(startReq)) + + startParse := time.Now() + + var response dmmf.Document + if err := json.Unmarshal(body, &response); err != nil { + logger.Info.Printf("json unmarshal: %s", err) + + return nil, err + } + + logger.Debug.Printf("[timing] request unmarshaling took %s", time.Since(startParse)) + + return &response, nil +} + +func (e *QueryEngine) IntrospectSDL(ctx context.Context) ([]byte, error) { + + startReq := time.Now() + + body, err := e.Request(ctx, "GET", "/sdl", nil) + if err != nil { + logger.Info.Printf("sdl request failed: %s", err) + return nil, err + } + + logger.Debug.Printf("[timing] query engine sdl request took %s", time.Since(startReq)) + + return body, nil +} diff --git a/engine/utils.go b/engine/utils.go new file mode 100644 index 000000000..a31d5bf19 --- /dev/null +++ b/engine/utils.go @@ -0,0 +1,77 @@ +package engine + +import ( + "bytes" + "encoding/json" + "regexp" + "strings" + + "github.com/vektah/gqlparser/v2/ast" + "github.com/vektah/gqlparser/v2/formatter" + "github.com/vektah/gqlparser/v2/gqlerror" + "github.com/vektah/gqlparser/v2/parser" +) + +func InlineQueryDocument(query *ast.QueryDocument, variable map[string]interface{}) (string, error) { + // 这里要循环处理,去除变量输入 + for _, operation := range query.Operations { + operation.VariableDefinitions = ast.VariableDefinitionList{} + } + + var buf bytes.Buffer + formatter.NewFormatter(&buf).FormatQueryDocument(query) + + bufstr := buf.String() + + for k, v := range variable { + // s, _ := json.MarshalIndent(v, "", "\t") // TODO:这里去掉引号 + s, _ := json.Marshal(v) // TODO:这里去掉引号 + ss := convert(string(s)) + bufstr = strings.ReplaceAll(bufstr, "$"+k, ss) + } + return bufstr, nil +} + +func InlineQuery(str string, variable map[string]interface{}) (string, error) { + query, err := parser.ParseQuery(&ast.Source{Input: str}) + if err != nil { + gqlErr := err.(*gqlerror.Error) + return "", gqlerror.List{gqlErr} + } + + return InlineQueryDocument(query, variable) +} + +// https://www.cnblogs.com/vicF/p/9517960.html +// {"id":{"equals":"ssss"}}=>{id:{equals:"ssss"}} +// ["id","id"]=>["id","id"] +func convert(s string) string { + reg := regexp.MustCompile("\"(\\w+)\"(\\s*:\\s*)") + res := reg.ReplaceAllString(s, "$1$2") + + return res +} + +// func convert2(s string) string { +// var b bytes.Buffer +// shouldSkip := true + +// for i := 0; i < len(s); i++ { +// c := string(s[i]) +// if c == `"` { +// if i > 0 && string(s[i-1]) == `:` { +// shouldSkip = false +// b.WriteString(c) +// continue +// } +// if shouldSkip { +// continue +// } +// shouldSkip = true +// } + +// b.WriteString(c) +// } + +// return b.String() +// } diff --git a/example/db/.gitignore b/example/db/.gitignore new file mode 100644 index 000000000..a0c7514ad --- /dev/null +++ b/example/db/.gitignore @@ -0,0 +1,2 @@ +# gitignore generated by Prisma Client Go. DO NOT EDIT. +*_gen.go diff --git a/example/go.mod b/example/go.mod new file mode 100644 index 000000000..1eba57d38 --- /dev/null +++ b/example/go.mod @@ -0,0 +1,13 @@ +module demo + +go 1.18 + +require ( + github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365 + github.com/joho/godotenv v1.4.0 + github.com/prisma/prisma-client-go v0.16.2 + github.com/shopspring/decimal v1.3.1 + github.com/takuoki/gocase v1.0.0 +) + +replace github.com/prisma/prisma-client-go => ../ diff --git a/example/go.sum b/example/go.sum new file mode 100644 index 000000000..07e1cdb08 --- /dev/null +++ b/example/go.sum @@ -0,0 +1,24 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365 h1:ECW73yc9MY7935nNYXUkK7Dz17YuSUI9yqRqYS8aBww= +github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE= +github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= +github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prisma/prisma-client-go v0.16.2 h1:Yh8pHJhjicJ3Nw66X6S5/Rfp2PZQ5r5j38fd4k5Qxa8= +github.com/prisma/prisma-client-go v0.16.2/go.mod h1:B1QEQQo4TLV9NzzrtOvW7pz4yOKXlxwMY0tKQivsdOU= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/takuoki/gocase v1.0.0 h1:gPwLJTWVm2T1kUiCsKirg/faaIUGVTI0FA3SYr75a44= +github.com/takuoki/gocase v1.0.0/go.mod h1:QgOKJrbuJoDrtoKswBX1/Dw8mJrkOV9tbQZJaxaJ6zc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/example/main.go b/example/main.go new file mode 100644 index 000000000..e4d984412 --- /dev/null +++ b/example/main.go @@ -0,0 +1,218 @@ +package main + +import ( + "context" + "fmt" + "github.com/prisma/prisma-client-go/engine" + "io/ioutil" +) + +const schmea = ` +datasource db { + // could be postgresql or mysql + provider = "sqlite" + url = "file:dev.db" + } + + generator db { + provider = "go run github.com/prisma/prisma-client-go" + // set the output folder and package name + // output = "./your-folder" + // package = "yourpackagename" + } + + model Post { + id String @default(cuid()) @id + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + title String + published Boolean + desc String? + } +` + +const mysqlSchema = `generator db { + provider = "go run github.com/prisma/prisma-client-go" +} + +datasource db { + provider = "mysql" + url = "mysql://root:shaoxiong123456@8.142.115.204:3306/main" +} + +model oauth_user { + id String @id @db.VarChar(50) + name String? @default("") @db.VarChar(50) + nick_name String? @default("") @db.VarChar(50) + user_name String? @unique(map: "name_index") @default("") @db.VarChar(50) + encryption_password String? @default("") @db.VarChar(250) + mobile String? @default("") @db.VarChar(11) + email String? @default("") @db.VarChar(50) + mate_data String? @db.Text + last_login_time DateTime? @db.Timestamp(0) + status Int? @default(0) @db.TinyInt + create_time DateTime? @default(now()) @db.Timestamp(0) + update_time DateTime? @db.Timestamp(0) + is_del Int? @default(0) @db.UnsignedTinyInt +} +` + +func main() { + // if err := run(); err != nil { + // panic(err) + // } + //migrationEngine := migrate.NewMigrationEngine() + // + //migrationEngine.Push("schema2.prisma") + //migrationEngine.Push2("schema1.prisma") + //migrationEngine.Push("schema2.prisma") + //migrationEngine.Push2("schema2.prisma") + + //introspectionEngine := introspection.NewIntrospectEngine() + //introspectionEngine.Pull("schema1.prisma") + //ntrospectionEngine.Pull2("schema1.prisma") + //introspectionEngine.Pull("schema1.prisma") + //introspectionEngine.Pull2("schema1.prisma") + ss, _ := ioutil.ReadFile("schema1.prisma") + engine.Push(string(ss)) + //engine.Pull("schema2.prisma") + // testDmmf() + //engine.QueryDMMF(mysqlSchema) + //testSdl1() +} + +func testDmmf() { + engine := engine.NewQueryEngine(schmea, false) + defer engine.Disconnect() + if err := engine.Connect(); err != nil { + panic(err) + } + dmmf, err := engine.IntrospectDMMF(context.TODO()) + if err != nil { + panic(err) + } + fmt.Println(dmmf.Datamodel) +} + +var querySchema = `{ result:findFirstoauth_user {id name nick_name}}` + +type OauthUser struct { + ID string `json:"id"` // id + Name string `json:"name"` // 姓名 + NickName string `json:"nick_name"` // 昵称 + UserName string `json:"user_name"` // 用户名 + EncryptionPassword string `json:"encryption_password"` // 加密后密码 + Mobile string `json:"mobile"` // 手机号 + Email string `json:"email"` // 邮箱 + LastLoginTime string `json:"last_login_time"` // 最后一次登陆时间 + Status int64 `json:"status"` // 状态 + MateData string `json:"mate_data"` // 其他信息(json字符串保存) + CreateTime string `json:"create_time"` // 创建时间 + UpdateTime string `json:"update_time"` // 修改时间 + IsDel int64 `json:"isDel"` // 是否删除 +} + +//func testSdl1() { +// +// queryEngine := engine.GetQueryEngineOnce(mysqlSchema) +// ctx := context.TODO() +// //var result OauthUser +// +// var response OauthUser +// payload := engine.GQLRequest{ +// Query: querySchema, +// Variables: map[string]interface{}{}, +// } +// err := queryEngine.Do(ctx, payload, &response) +// //result, err := engine.Do(ctx, querySchema) +// //fmt.Print(result) +// if err != nil { +// panic(err) +// } +//} + +func testSdl() { + engine := engine.NewQueryEngine(mysqlSchema, false) + defer engine.Disconnect() + if err := engine.Connect(); err != nil { + panic(err) + } + ctx := context.TODO() + var result OauthUser + err := engine.Do(ctx, querySchema, result) + //sdl, err := engine.IntrospectSDL(ctx) + + if err != nil { + panic(err) + } + //fmt.Println(string(sdl)) +} + +// func run() error { +// client := db.NewClient() +// if err := client.Prisma.Connect(); err != nil { +// return err +// } + +// defer func() { +// if err := client.Prisma.Disconnect(); err != nil { +// panic(err) +// } +// }() + +// ctx := context.Background() + +// // create a post +// createdPost, err := client.Post.CreateOne( +// db.Post.Title.Set("Hi from Prisma!"), +// db.Post.Published.Set(true), +// db.Post.Desc.Set("Prisma is a database toolkit and makes databases easy."), +// ).Exec(ctx) +// if err != nil { +// return err +// } + +// result, _ := json.MarshalIndent(createdPost, "", " ") +// fmt.Printf("created post: %s\n", result) + +// // find a single post +// post, err := client.Post.FindUnique( +// db.Post.ID.Equals(createdPost.ID), +// ).Exec(ctx) +// if err != nil { +// return err +// } + +// result, _ = json.MarshalIndent(post, "", " ") +// fmt.Printf("post: %s\n", result) + +// // for optional/nullable values, you need to check the function and create two return values +// // `desc` is a string, and `ok` is a bool whether the record is null or not. If it's null, +// // `ok` is false, and `desc` will default to Go's default values; in this case an empty string (""). Otherwise, +// // `ok` is true and `desc` will be "my description". +// desc, ok := post.Desc() +// if !ok { +// return fmt.Errorf("post's description is null") +// } + +// fmt.Printf("The posts's description is: %s\n", desc) + +// createUserA := client.Post.CreateOne( +// db.Post.Title.Set("2"), +// db.Post.Published.Set(true), +// db.Post.Desc.Set("222."), +// ).Tx() + +// createUserB := client.Post.CreateOne( +// db.Post.Title.Set("3"), +// db.Post.Published.Set(true), +// db.Post.Desc.Set("222."), +// ).Tx() + +// tx := client.Prisma.Transaction(createUserA, createUserB) +// if err := tx.Exec(ctx); err != nil { +// panic(err) +// } + +// return nil +// } diff --git a/example/schema.prisma b/example/schema.prisma new file mode 100644 index 000000000..0ea9ac540 --- /dev/null +++ b/example/schema.prisma @@ -0,0 +1,30 @@ +datasource db { + // could be postgresql or mysql + provider = "sqlite" + url = "file:dev.db" +} + +generator db { + provider = "go run github.com/prisma/prisma-client-go" + // set the output folder and package name + // output = "./your-folder" + // package = "yourpackagename" +} + +model Post { + id String @id @default(cuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + title String + published Boolean + desc String? +} + +model User { + id String @id @default(cuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + name String + gender Boolean + desc String? +} diff --git a/example/schema1.prisma b/example/schema1.prisma new file mode 100644 index 000000000..272042452 --- /dev/null +++ b/example/schema1.prisma @@ -0,0 +1,39 @@ +datasource db { + // could be postgresql or mysql + provider = "mysql" + url = "mysql://root:shaoxiong123456@8.142.115.204:3306/wunder-demo" +} + +generator db { + provider = "go run github.com/prisma/prisma-client-go" + // set the output folder and package name + // output = "./your-folder" + // package = "yourpackagename" +} + +model Post { + id String @id @default(cuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + title String + published Boolean + desc String? +} +model oauth_user { + id String @id @db.VarChar(50) + name String? @default("") @db.VarChar(50) + nick_name String? @default("") @db.VarChar(50) + user_name String? @unique(map: "name_index") @default("") @db.VarChar(50) + encryption_password String? @default("") @db.VarChar(250) + mobile String? @default("") @db.VarChar(11) + email String? @default("") @db.VarChar(50) + mate_data String? @db.Text + last_login_time DateTime? @db.Timestamp(0) + status Int? @default(0) @db.TinyInt + create_time DateTime? @default(now()) @db.Timestamp(0) + update_time DateTime? @db.Timestamp(0) + is_del Int? @default(0) @db.UnsignedTinyInt +} + + + diff --git a/example/schema2.prisma b/example/schema2.prisma new file mode 100644 index 000000000..627ac7197 --- /dev/null +++ b/example/schema2.prisma @@ -0,0 +1,24 @@ +generator db { + provider = "go run github.com/prisma/prisma-client-go" +} + +datasource db { + provider = "mysql" + url = "mysql://root:shaoxiong123456@8.142.115.204:3306/wunder-demo" +} + +model oauth_user { + id String @id @db.VarChar(50) + name String? @default("") @db.VarChar(50) + nick_name String? @default("") @db.VarChar(50) + user_name String? @unique(map: "name_index") @default("") @db.VarChar(50) + encryption_password String? @default("") @db.VarChar(250) + mobile String? @default("") @db.VarChar(11) + email String? @default("") @db.VarChar(50) + mate_data String? @db.Text + last_login_time DateTime? @db.Timestamp(0) + status Int? @default(0) @db.TinyInt + create_time DateTime? @default(now()) @db.Timestamp(0) + update_time DateTime? @db.Timestamp(0) + is_del Int? @default(0) @db.UnsignedTinyInt +} diff --git a/generator.go b/generator.go index 89632ca98..1918f1839 100644 --- a/generator.go +++ b/generator.go @@ -5,14 +5,13 @@ import ( "encoding/json" "errors" "fmt" - "io" - "log" - "os" - "path" - "github.com/prisma/prisma-client-go/generator" "github.com/prisma/prisma-client-go/jsonrpc" "github.com/prisma/prisma-client-go/logger" + "io" + "log" + "os" + "path/filepath" ) var writeDebugFile = os.Getenv("PRISMA_CLIENT_GO_WRITE_DMMF_FILE") != "" @@ -68,7 +67,7 @@ func invokePrisma() error { case "getManifest": response = jsonrpc.ManifestResponse{ Manifest: jsonrpc.Manifest{ - DefaultOutput: path.Join(".", "db"), + DefaultOutput: filepath.ToSlash(filepath.Join(".", "db")), PrettyName: "Prisma Client Go", }, } diff --git a/generator/run.go b/generator/run.go index b7a1652bf..775740eed 100644 --- a/generator/run.go +++ b/generator/run.go @@ -7,7 +7,6 @@ import ( "go/build" "go/format" "os" - "path" "path/filepath" "strings" "text/template" @@ -37,7 +36,7 @@ func Run(input *Root) error { if err := os.MkdirAll(input.Generator.Output.Value, os.ModePerm); err != nil { return fmt.Errorf("could not create output directory: %w", err) } - if err := os.WriteFile(path.Join(input.Generator.Output.Value, ".gitignore"), []byte(gitignore), 0644); err != nil { + if err := os.WriteFile(filepath.Join(input.Generator.Output.Value, ".gitignore"), []byte(gitignore), 0644); err != nil { return fmt.Errorf("could not write .gitignore: %w", err) } } @@ -124,7 +123,7 @@ func generateClient(input *Root) error { } // TODO make this configurable - outFile := path.Join(output, "db_gen.go") + outFile := filepath.Join(output, "db_gen.go") if err := os.WriteFile(outFile, formatted, 0644); err != nil { return fmt.Errorf("could not write template data to file writer %s: %w", outFile, err) } @@ -184,7 +183,7 @@ func generateQueryEngineFiles(binaryTargets []string, pkg, outputDir string) err } filename := fmt.Sprintf("query-engine-%s_gen.go", name) - to := path.Join(outputDir, filename) + to := filepath.Join(outputDir, filename) // TODO check if already exists, but make sure version matches if err := bindata.WriteFile(strings.ReplaceAll(name, "-", "_"), pkg, pt, enginePath, to); err != nil { diff --git a/go.mod b/go.mod index 82416434c..d44a8343a 100644 --- a/go.mod +++ b/go.mod @@ -8,4 +8,5 @@ require ( github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.8.0 github.com/takuoki/gocase v1.0.0 + github.com/vektah/gqlparser/v2 v2.5.1 ) diff --git a/go.sum b/go.sum index 6eaad2894..47e244d5b 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/agnivade/levenshtein v1.0.1 h1:3oJU7J3FGFmyhn8KHjmVaZCN5hxTr7GxgRue+sxIXdQ= +github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM= +github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ= +github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -5,19 +9,34 @@ github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365 h1:ECW73yc9MY79 github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE= github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= +github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/takuoki/gocase v1.0.0 h1:gPwLJTWVm2T1kUiCsKirg/faaIUGVTI0FA3SYr75a44= github.com/takuoki/gocase v1.0.0/go.mod h1:QgOKJrbuJoDrtoKswBX1/Dw8mJrkOV9tbQZJaxaJ6zc= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +github.com/vektah/gqlparser/v2 v2.5.1 h1:ZGu+bquAY23jsxDRcYpWjttRZrUz07LbiY77gUOHcr4= +github.com/vektah/gqlparser/v2 v2.5.1/go.mod h1:mPgqFBu/woKTVYWyNk8cO3kh4S/f4aRFZrvOnp3hmCs= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=