Skip to content

Commit

Permalink
Extract feature detection package (cli#5494)
Browse files Browse the repository at this point in the history
  • Loading branch information
samcoe authored May 17, 2022
1 parent f8b3ff9 commit 539b150
Show file tree
Hide file tree
Showing 9 changed files with 492 additions and 488 deletions.
280 changes: 0 additions & 280 deletions api/queries_pr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,12 @@ import (
"context"
"fmt"
"net/http"
"strings"
"time"

"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/set"
"github.com/shurcooL/githubv4"
"golang.org/x/sync/errgroup"
)

type PullRequestsPayload struct {
ViewerCreated PullRequestAndTotalCount
ReviewRequested PullRequestAndTotalCount
CurrentPR *PullRequest
DefaultBranch string
}

type PullRequestAndTotalCount struct {
TotalCount int
PullRequests []PullRequest
Expand Down Expand Up @@ -269,275 +258,6 @@ func (pr *PullRequest) DisplayableReviews() PullRequestReviews {
return PullRequestReviews{Nodes: published, TotalCount: len(published)}
}

type pullRequestFeature struct {
HasReviewDecision bool
HasStatusCheckRollup bool
HasBranchProtectionRule bool
}

func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prFeatures pullRequestFeature, err error) {
if !ghinstance.IsEnterprise(hostname) {
prFeatures.HasReviewDecision = true
prFeatures.HasStatusCheckRollup = true
prFeatures.HasBranchProtectionRule = true
return
}

var featureDetection struct {
PullRequest struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"PullRequest: __type(name: \"PullRequest\")"`
Commit struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"Commit: __type(name: \"Commit\")"`
}

// needs to be a separate query because the backend only supports 2 `__type` expressions in one query
var featureDetection2 struct {
Ref struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"Ref: __type(name: \"Ref\")"`
}

v4 := graphQLClient(httpClient, hostname)

g := new(errgroup.Group)
g.Go(func() error {
return v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil)
})
g.Go(func() error {
return v4.QueryNamed(context.Background(), "PullRequest_fields2", &featureDetection2, nil)
})

err = g.Wait()
if err != nil {
return
}

for _, field := range featureDetection.PullRequest.Fields {
switch field.Name {
case "reviewDecision":
prFeatures.HasReviewDecision = true
}
}
for _, field := range featureDetection.Commit.Fields {
switch field.Name {
case "statusCheckRollup":
prFeatures.HasStatusCheckRollup = true
}
}
for _, field := range featureDetection2.Ref.Fields {
switch field.Name {
case "branchProtectionRule":
prFeatures.HasBranchProtectionRule = true
}
}
return
}

type StatusOptions struct {
CurrentPR int
HeadRef string
Username string
Fields []string
}

func PullRequestStatus(client *Client, repo ghrepo.Interface, options StatusOptions) (*PullRequestsPayload, error) {
type edges struct {
TotalCount int
Edges []struct {
Node PullRequest
}
}

type response struct {
Repository struct {
DefaultBranchRef struct {
Name string
}
PullRequests edges
PullRequest *PullRequest
}
ViewerCreated edges
ReviewRequested edges
}

var fragments string
if len(options.Fields) > 0 {
fields := set.NewStringSet()
fields.AddValues(options.Fields)
// these are always necessary to find the PR for the current branch
fields.AddValues([]string{"isCrossRepository", "headRepositoryOwner", "headRefName"})
gr := PullRequestGraphQL(fields.ToSlice())
fragments = fmt.Sprintf("fragment pr on PullRequest{%s}fragment prWithReviews on PullRequest{...pr}", gr)
} else {
var err error
fragments, err = pullRequestFragment(client.http, repo.RepoHost())
if err != nil {
return nil, err
}
}

queryPrefix := `
query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
repository(owner: $owner, name: $repo) {
defaultBranchRef {
name
}
pullRequests(headRefName: $headRefName, first: $per_page, orderBy: { field: CREATED_AT, direction: DESC }) {
totalCount
edges {
node {
...prWithReviews
}
}
}
}
`
if options.CurrentPR > 0 {
queryPrefix = `
query PullRequestStatus($owner: String!, $repo: String!, $number: Int!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
repository(owner: $owner, name: $repo) {
defaultBranchRef {
name
}
pullRequest(number: $number) {
...prWithReviews
baseRef {
branchProtectionRule {
requiredApprovingReviewCount
}
}
}
}
`
}

query := fragments + queryPrefix + `
viewerCreated: search(query: $viewerQuery, type: ISSUE, first: $per_page) {
totalCount: issueCount
edges {
node {
...prWithReviews
}
}
}
reviewRequested: search(query: $reviewerQuery, type: ISSUE, first: $per_page) {
totalCount: issueCount
edges {
node {
...pr
}
}
}
}
`

currentUsername := options.Username
if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) {
var err error
currentUsername, err = CurrentLoginName(client, repo.RepoHost())
if err != nil {
return nil, err
}
}

viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername)
reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername)

currentPRHeadRef := options.HeadRef
branchWithoutOwner := currentPRHeadRef
if idx := strings.Index(currentPRHeadRef, ":"); idx >= 0 {
branchWithoutOwner = currentPRHeadRef[idx+1:]
}

variables := map[string]interface{}{
"viewerQuery": viewerQuery,
"reviewerQuery": reviewerQuery,
"owner": repo.RepoOwner(),
"repo": repo.RepoName(),
"headRefName": branchWithoutOwner,
"number": options.CurrentPR,
}

var resp response
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}

var viewerCreated []PullRequest
for _, edge := range resp.ViewerCreated.Edges {
viewerCreated = append(viewerCreated, edge.Node)
}

var reviewRequested []PullRequest
for _, edge := range resp.ReviewRequested.Edges {
reviewRequested = append(reviewRequested, edge.Node)
}

var currentPR = resp.Repository.PullRequest
if currentPR == nil {
for _, edge := range resp.Repository.PullRequests.Edges {
if edge.Node.HeadLabel() == currentPRHeadRef {
currentPR = &edge.Node
break // Take the most recent PR for the current branch
}
}
}

payload := PullRequestsPayload{
ViewerCreated: PullRequestAndTotalCount{
PullRequests: viewerCreated,
TotalCount: resp.ViewerCreated.TotalCount,
},
ReviewRequested: PullRequestAndTotalCount{
PullRequests: reviewRequested,
TotalCount: resp.ReviewRequested.TotalCount,
},
CurrentPR: currentPR,
DefaultBranch: resp.Repository.DefaultBranchRef.Name,
}

return &payload, nil
}

func pullRequestFragment(httpClient *http.Client, hostname string) (string, error) {
cachedClient := NewCachedClient(httpClient, time.Hour*24)
prFeatures, err := determinePullRequestFeatures(cachedClient, hostname)
if err != nil {
return "", err
}

fields := []string{
"number", "title", "state", "url", "isDraft", "isCrossRepository",
"headRefName", "headRepositoryOwner", "mergeStateStatus",
}
if prFeatures.HasStatusCheckRollup {
fields = append(fields, "statusCheckRollup")
}
if prFeatures.HasBranchProtectionRule {
fields = append(fields, "requiresStrictStatusChecks")
}

var reviewFields []string
if prFeatures.HasReviewDecision {
reviewFields = append(reviewFields, "reviewDecision", "latestReviews")
}

fragments := fmt.Sprintf(`
fragment pr on PullRequest {%s}
fragment prWithReviews on PullRequest {...pr,%s}
`, PullRequestGraphQL(fields), PullRequestGraphQL(reviewFields))
return fragments, nil
}

// CreatePullRequest creates a pull request in a GitHub repository
func CreatePullRequest(client *Client, repo *Repository, params map[string]interface{}) (*PullRequest, error) {
query := `
Expand Down
Loading

0 comments on commit 539b150

Please sign in to comment.