Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: Add custom fields & role assignment to OpenID strategy #5798

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
5 changes: 4 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,9 @@ OPENID_NAME_CLAIM=

OPENID_BUTTON_LABEL=
OPENID_IMAGE_URL=
OPENID_CUSTOM_DATA=
OPENID_PROVIDER=
OPENID_ADMIN_ROLE=

# LDAP
LDAP_URL=
Expand Down Expand Up @@ -527,4 +530,4 @@ HELP_AND_FAQ_URL=https://librechat.ai
#=====================================================#
# OpenWeather #
#=====================================================#
OPENWEATHER_API_KEY=
OPENWEATHER_API_KEY=
6 changes: 6 additions & 0 deletions api/models/schema/userSchema.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const { SystemRoles } = require('librechat-data-provider');
* @property {string} [googleId] - Optional Google ID for the user
* @property {string} [facebookId] - Optional Facebook ID for the user
* @property {string} [openidId] - Optional OpenID ID for the user
* @property {Map<string, string>} [customOpenIdData] - A map containing provider-specific custom data retrieved via OpenID Connect.
* @property {string} [ldapId] - Optional LDAP ID for the user
* @property {string} [githubId] - Optional GitHub ID for the user
* @property {string} [discordId] - Optional Discord ID for the user
Expand Down Expand Up @@ -97,6 +98,11 @@ const userSchema = mongoose.Schema(
unique: true,
sparse: true,
},
customOpenIdData: {
type: Map,
of: mongoose.Schema.Types.Mixed,
default: {},
},
ldapId: {
type: String,
unique: true,
Expand Down
5 changes: 3 additions & 2 deletions api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"@langchain/google-vertexai": "^0.1.8",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.0.4",
"@microsoft/microsoft-graph-client": "^3.0.7",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "1.7.8",
"bcryptjs": "^2.4.3",
Expand Down Expand Up @@ -86,8 +87,8 @@
"ollama": "^0.5.0",
"openai": "^4.47.1",
"openai-chat-tokens": "^0.2.8",
"openid-client": "^5.4.2",
"passport": "^0.6.0",
"openid-client": "^5.7.1",
"passport": "^0.7.0",
"passport-apple": "^2.0.2",
"passport-discord": "^0.1.4",
"passport-facebook": "^3.0.0",
Expand Down
167 changes: 167 additions & 0 deletions api/strategies/OpenId/openidDataMapper.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
const fetch = require('node-fetch');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { logger } = require('~/config');
const { URL } = require('url');

// Microsoft SDK
const { Client: MicrosoftGraphClient } = require('@microsoft/microsoft-graph-client');

/**
* Base class for provider-specific data mappers.
*/
class BaseDataMapper {
/**
* Map custom OpenID data.
* @param {string} accessToken - The access token to authenticate the request.
* @param {string|Array<string>} customQuery - Either a full query string (if it contains operators)
* or an array of fields to select.
* @returns {Promise<Map<string, any>>} A promise that resolves to a map of custom fields.
* @throws {Error} Throws an error if not implemented in the subclass.
*/
async mapCustomData(accessToken, customQuery) {
throw new Error('mapCustomData() must be implemented by subclasses');
}

/**
* Optionally handle proxy settings for HTTP requests.
* @returns {Object} Configuration object with proxy settings if PROXY is set.
*/
getProxyOptions() {
if (process.env.PROXY) {
const agent = new HttpsProxyAgent(process.env.PROXY);
return { agent };
}
return {};
}
}

/**
* Microsoft-specific data mapper using the Microsoft Graph SDK.
*/
class MicrosoftDataMapper extends BaseDataMapper {
/**
* Initializes the MicrosoftGraphClient once for reuse.
*/
constructor() {
super();
this.accessToken = null;

this.client = MicrosoftGraphClient.init({
defaultVersion: 'beta',
authProvider: (done) => {
// The authProvider will be called for each request to get the token
if (this.accessToken) {
done(null, this.accessToken);
} else {
done(new Error('Access token is not set.'), null);
}
},
fetch: fetch,
...this.getProxyOptions(),
});

// Bind methods to maintain context
this.mapCustomData = this.mapCustomData.bind(this);
this.cleanData = this.cleanData.bind(this);
}

/**
* Set the access token for the client.
* This method should be called before making any requests.
*
* @param {string} accessToken - The access token.
*/
setAccessToken(accessToken) {
if (!accessToken || typeof accessToken !== 'string') {
throw new Error('[MicrosoftDataMapper] Invalid access token provided.');
}
this.accessToken = accessToken;
}

/**
* Map custom OpenID data using the Microsoft Graph SDK.
*
* @param {string} accessToken - The access token to authenticate the request.
* @param {string|Array<string>} customQuery - Fields to select from the Microsoft Graph API.
* @returns {Promise<Map<string, any>>} A promise that resolves to a map of custom fields.
*/
async mapCustomData(accessToken, customQuery) {
try {
this.setAccessToken(accessToken);

if (!customQuery) {
logger.warn('[MicrosoftDataMapper] No customQuery provided.');
return new Map();
}

// Convert customQuery to a comma-separated string if it's an array
const fields = Array.isArray(customQuery) ? customQuery.join(',') : customQuery;

if (!fields) {
logger.warn('[MicrosoftDataMapper] No fields specified in customQuery.');
return new Map();
}

const result = await this.client
.api('/me')
.select(fields)
.get();

const cleanedData = this.cleanData(result);
return new Map(Object.entries(cleanedData));
} catch (error) {
// Handle specific Microsoft Graph errors if needed
logger.error(`[MicrosoftDataMapper] Error fetching user data: ${error.message}`, { stack: error.stack });
return new Map();
}
}

/**
* Recursively remove all keys starting with @odata. from an object and convert Maps.
*
* @param {object|Array} obj - The object or array to clean.
* @returns {object|Array} - The cleaned object or array.
*/
cleanData(obj) {
if (Array.isArray(obj)) {
return obj.map(this.cleanData);
} else if (obj && typeof obj === 'object') {
return Object.entries(obj).reduce((acc, [key, value]) => {
if (!key.startsWith('@odata.')) {
acc[key] = this.cleanData(value);
}
return acc;
}, {});
}
return obj;
}
}

/**
* Map provider names to their specific data mappers.
*/
const PROVIDER_MAPPERS = {
microsoft: MicrosoftDataMapper,
};

/**
* Abstraction layer that returns a provider-specific mapper instance.
*/
class OpenIdDataMapper {
/**
* Retrieve an instance of the mapper for the specified provider.
*
* @param {string} provider - The name of the provider (e.g., 'microsoft').
* @returns {BaseDataMapper} An instance of the specific data mapper for the provider.
* @throws {Error} Throws an error if no mapper is found for the specified provider.
*/
static getMapper(provider) {
const MapperClass = PROVIDER_MAPPERS[provider.toLowerCase()];
if (!MapperClass) {
throw new Error(`No mapper found for provider: ${provider}`);
}
return new MapperClass();
}
}

module.exports = OpenIdDataMapper;
Loading