Skip to content

Commit

Permalink
feat: Use OAuth flow to generate R2 tokens for Pipelines
Browse files Browse the repository at this point in the history
This commit changes the generateR2Tokens flow which will direct the user
to the web browser to perform a OAuth flow to grant the Workers
Pipelines client the ability to generate R2 tokens on behalf of the
user. This will only run if the user does not provide the credentials as
CLI parameters.

Due to requiring user interactivity, and reliance on the callbacks,
there is no easy way to support a "headless" mode for `wrangler pipelines
create` (or `update`) unless the user provides the tokens as arguments.
The same applies for testing this flow, which can only be done manually
at this time.
  • Loading branch information
cmackenzie1 committed Dec 16, 2024
1 parent 77bd9a1 commit c11f0f4
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 82 deletions.
23 changes: 4 additions & 19 deletions packages/wrangler/src/__tests__/pipelines.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { http, HttpResponse } from "msw";
import { normalizeOutput } from "../../e2e/helpers/normalize";
import { __testSkipDelays } from "../pipelines";
import { endEventLoop } from "./helpers/end-event-loop";
import { mockAccountId, mockApiToken } from "./helpers/mock-account-id";
import { mockConsoleMethods } from "./helpers/mock-console";
Expand Down Expand Up @@ -115,7 +114,7 @@ describe("pipelines", () => {
return requests;
}

function mockCreeatR2TokenFailure(bucket: string) {
function mockCreateR2TokenFailure(bucket: string) {
const requests = { count: 0 };
msw.use(
http.get(
Expand Down Expand Up @@ -310,9 +309,6 @@ describe("pipelines", () => {
);
return requests;
}
beforeAll(() => {
__testSkipDelays();
});

it("shows usage details", async () => {
await runWrangler("pipelines");
Expand Down Expand Up @@ -383,15 +379,6 @@ describe("pipelines", () => {
`);
});

it("should create a pipeline", async () => {
const tokenReq = mockCreateR2Token("test-bucket");
const requests = mockCreateRequest("my-pipeline");
await runWrangler("pipelines create my-pipeline --r2 test-bucket");

expect(tokenReq.count).toEqual(3);
expect(requests.count).toEqual(1);
});

it("should create a pipeline with explicit credentials", async () => {
const requests = mockCreateRequest("my-pipeline");
await runWrangler(
Expand All @@ -401,7 +388,7 @@ describe("pipelines", () => {
});

it("should fail a missing bucket", async () => {
const requests = mockCreeatR2TokenFailure("bad-bucket");
const requests = mockCreateR2TokenFailure("bad-bucket");
await expect(
runWrangler("pipelines create bad-pipeline --r2 bad-bucket")
).rejects.toThrowError();
Expand Down Expand Up @@ -543,7 +530,6 @@ describe("pipelines", () => {

it("should update a pipeline with new bucket", async () => {
const pipeline: Pipeline = samplePipeline;
const tokenReq = mockCreateR2Token("new-bucket");
mockShowRequest(pipeline.name, pipeline);

const update = JSON.parse(JSON.stringify(pipeline));
Expand All @@ -552,13 +538,12 @@ describe("pipelines", () => {
endpoint: "https://some-account-id.r2.cloudflarestorage.com",
access_key_id: "service-token-id",
secret_access_key:
"be22cbae9c1585c7b61a92fdb75afd10babd535fb9b317f90ac9a9ca896d02d7",
"my-secret-access-key",
};
const updateReq = mockUpdateRequest(update.name, update);

await runWrangler("pipelines update my-pipeline --r2 new-bucket");
await runWrangler("pipelines update my-pipeline --r2 new-bucket --access-key-id service-token-id --secret-access-key my-secret-access-key");

expect(tokenReq.count).toEqual(3);
expect(updateReq.count).toEqual(1);
});

Expand Down
110 changes: 78 additions & 32 deletions packages/wrangler/src/pipelines/client.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import assert from "node:assert";
import { createHash } from "node:crypto";
import http from "node:http";
import { fetchResult } from "../cfetch";
import { getCloudflareApiEnvironmentFromEnv } from "../environment-variables/misc-variables";
import { UserError } from "../errors";
import { logger } from "../logger";
import openInBrowser from "../open-in-browser";
import type { R2BucketInfo } from "../r2/helpers";

// ensure this is in sync with:
Expand Down Expand Up @@ -96,44 +102,84 @@ export type PermissionGroup = {
scopes: string[];
};

interface S3AccessKey {
accessKeyId: string;
secretAccessKey: string;
}

// Generate a Service Token to write to R2 for a pipeline
export async function generateR2ServiceToken(
label: string,
accountId: string,
bucket: string
): Promise<ServiceToken> {
const res = await fetchResult<PermissionGroup[]>(
`/user/tokens/permission_groups`,
{
method: "GET",
}
);
const perm = res.find(
(g) => g.name == "Workers R2 Storage Bucket Item Write"
);
if (!perm) {
throw new Error("Missing R2 Permissions");
}

// generate specific bucket write token for pipeline
const body = JSON.stringify({
policies: [
{
effect: "allow",
permission_groups: [{ id: perm.id }],
resources: {
[`com.cloudflare.edge.r2.bucket.${accountId}_default_${bucket}`]: "*",
},
},
],
name: label,
bucketName: string,
pipelineName: string
): Promise<S3AccessKey> {
let server: http.Server;
let loginTimeoutHandle: ReturnType<typeof setTimeout>;
const timerPromise = new Promise<S3AccessKey>((_, reject) => {
loginTimeoutHandle = setTimeout(() => {
server.close();
clearTimeout(loginTimeoutHandle);
reject(
new UserError(
"Timed out waiting for authorization code, please try again."
)
);
}, 120000); // wait for 120 seconds for the user to authorize
});

return await fetchResult<ServiceToken>(`/user/tokens`, {
method: "POST",
headers: API_HEADERS,
body,
const loginPromise = new Promise<S3AccessKey>((resolve, reject) => {
server = http.createServer(async (request, response) => {
assert(request.url, "This request doesn't have a URL"); // This should never happen

if (request.method !== "GET") {
response.writeHead(405);
response.end("Method not allowed.");
return;
}

const { pathname, searchParams } = new URL(
request.url,
`http://${request.headers.host}`
);

if (pathname !== "/") {
response.writeHead(404);
response.end("Not found.");
return;
}

// Retrieve values from the URL parameters
const accessKeyId = searchParams.get("access-key-id");
const secretAccessKey = searchParams.get("secret-access-key");

if (!accessKeyId || !secretAccessKey) {
reject(new UserError("Missing required URL parameters"));
return;
}

resolve({ accessKeyId, secretAccessKey } as S3AccessKey);
// Do a final redirect to "clear" the URL of the sensitive URL parameters that were returned.
response.writeHead(307, {
Location:
"https://welcome.developers.workers.dev/wrangler-oauth-consent-granted",
});
response.end();
});

server.listen(8976, "localhost");
});

const env = getCloudflareApiEnvironmentFromEnv();
const oauthDomain =
env === "staging"
? "oauth.pipelines-staging.cloudflare.com"
: "oauth.pipelines.cloudflare.com";

const urlToOpen = `https://${oauthDomain}/oauth/login?accountId=${accountId}&bucketName=${bucketName}&pipelineName=${pipelineName}`;
logger.log(`Opening a link in your default browser: ${urlToOpen}`);
await openInBrowser(urlToOpen);

return Promise.race([timerPromise, loginPromise]);
}

// Get R2 bucket information from v4 API
Expand Down
44 changes: 13 additions & 31 deletions packages/wrangler/src/pipelines/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { readConfig } from "../config";
import { sleep } from "../deploy/deploy";
import { FatalError, UserError } from "../errors";
import { printWranglerBanner } from "../index";
import { logger } from "../logger";
Expand All @@ -13,7 +12,6 @@ import {
getPipeline,
getR2Bucket,
listPipelines,
sha256,
updatePipeline,
} from "./client";
import type { CommonYargsArgv, CommonYargsOptions } from "../yargs-types";
Expand All @@ -25,42 +23,31 @@ import type {
} from "./client";
import type { Argv } from "yargs";

// flag to skip delays for tests
let __testSkipDelaysFlag = false;

async function authorizeR2Bucket(
name: string,
pipelineName: string,
accountId: string,
bucket: string
bucketName: string
) {
try {
await getR2Bucket(accountId, bucket);
await getR2Bucket(accountId, bucketName);
} catch (err) {
if (err instanceof APIError) {
if (err.code == 10006) {
throw new FatalError(`The R2 bucket [${bucket}] doesn't exist`);
throw new FatalError(`The R2 bucket [${bucketName}] doesn't exist`);
}
}
throw err;
}

logger.log(`🌀 Authorizing R2 bucket "${bucket}"`);
logger.log(`🌀 Authorizing R2 bucket "${bucketName}"`);

const serviceToken = await generateR2ServiceToken(
`Service token for Pipeline ${name}`,
accountId,
bucket
bucketName,
pipelineName
);
const access_key_id = serviceToken.id;
const secret_access_key = sha256(serviceToken.value);

// wait for token to settle/propagate
!__testSkipDelaysFlag && (await sleep(3000));

return {
secret_access_key,
access_key_id,
};
return serviceToken;
}

function getAccountR2Endpoint(accountId: string) {
Expand Down Expand Up @@ -240,8 +227,8 @@ export function pipelines(pipelineYargs: CommonYargsArgv) {
accountId,
pipelineConfig.destination.path.bucket
);
destination.credentials.access_key_id = auth.access_key_id;
destination.credentials.secret_access_key = auth.secret_access_key;
destination.credentials.access_key_id = auth.accessKeyId;
destination.credentials.secret_access_key = auth.secretAccessKey;
}

if (!destination.credentials.access_key_id) {
Expand Down Expand Up @@ -415,8 +402,8 @@ export function pipelines(pipelineYargs: CommonYargsArgv) {
accountId,
destination.path.bucket
);
destination.credentials.access_key_id = auth.access_key_id;
destination.credentials.secret_access_key = auth.secret_access_key;
destination.credentials.access_key_id = auth.accessKeyId;
destination.credentials.secret_access_key = auth.secretAccessKey;
}
if (!destination.credentials.access_key_id) {
throw new FatalError("Requires a r2 access key id");
Expand Down Expand Up @@ -463,7 +450,7 @@ export function pipelines(pipelineYargs: CommonYargsArgv) {
args.authentication !== undefined
? // if auth specified, use it
args.authentication
: // if auth not specified, use previos value or default(false)
: // if auth not specified, use previous value or default(false)
source?.authentication,
} satisfies HttpSource);
}
Expand Down Expand Up @@ -521,8 +508,3 @@ export function pipelines(pipelineYargs: CommonYargsArgv) {
}
);
}

// Test exception to remove delays
export function __testSkipDelays() {
__testSkipDelaysFlag = true;
}

0 comments on commit c11f0f4

Please sign in to comment.