Skip to content

Commit

Permalink
Merge pull request #108 from prisma-idb/76-aggregate-fix
Browse files Browse the repository at this point in the history
Bugfixes
  • Loading branch information
WhyAsh5114 authored Jan 9, 2025
2 parents 14dd14e + ca34904 commit e3245fc
Show file tree
Hide file tree
Showing 11 changed files with 799 additions and 373 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,16 @@ export function addDateTimeUpdateHandler(utilsFile: SourceFile, models: readonly
.conditionalWrite(nullableDateTimeFieldPresent, ` || dateTimeUpdate === null`)
.writeLine(`)`)
.block(() => {
writer.writeLine(`(record[fieldName] as ${fieldType}) = new Date(dateTimeUpdate);`);
writer
.writeLine(`(record[fieldName] as ${fieldType}) = `)
.conditionalWrite(nullableDateTimeFieldPresent, () => `dateTimeUpdate === null ? null : `)
.write(`new Date(dateTimeUpdate);`);
});
writer.writeLine(`else if (dateTimeUpdate.set !== undefined)`).block(() => {
writer.writeLine(`(record[fieldName] as ${fieldType}) = new Date(dateTimeUpdate.set);`);
writer
.writeLine(`(record[fieldName] as ${fieldType}) = `)
.conditionalWrite(nullableDateTimeFieldPresent, () => `dateTimeUpdate.set === null ? null : `)
.write(`new Date(dateTimeUpdate.set);`);
});
},
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,25 @@ export function addAggregateMethod(modelClass: ClassDeclaration, model: Model) {
statements: (writer) => {
addTxAndRecordSetup(writer, model);
addCountHandling(writer);
addAvgHandling(writer, model);
addSumHandling(writer, model);
addMinHandling(writer, model);
addMaxHandling(writer, model);

const hasAvgOrSum = model.fields
.filter(({ isList }) => !isList)
.some((field) => field.type === "Float" || field.type === "Int" || field.type === "Decimal");
const hasMinMax =
hasAvgOrSum ||
model.fields
.filter(({ isList }) => !isList)
.some((field) => field.type === "DateTime" || field.type === "String");

if (hasMinMax) {
addMinHandling(writer, model);
addMaxHandling(writer, model);
}

if (hasAvgOrSum) {
addAvgHandling(writer, model);
addSumHandling(writer, model);
}
writer.writeLine(`return result as unknown as Prisma.Result<Prisma.${model.name}Delegate, Q, "aggregate">;`);
},
});
Expand Down Expand Up @@ -89,31 +104,139 @@ function addSumHandling(writer: CodeBlockWriter, model: Model) {
}

function addMinHandling(writer: CodeBlockWriter, model: Model) {
const nonListFields = model.fields.filter(({ isList }) => !isList);

const numericFields = nonListFields
.filter(({ isList }) => !isList)
.filter((field) => field.type === "Float" || field.type === "Int" || field.type === "Decimal")
.map((field) => field.name);
const dateTimeFields = nonListFields.filter((field) => field.type === "DateTime").map((field) => field.name);
const stringFields = nonListFields.filter((field) => field.type === "String").map((field) => field.name);
const booleanFields = nonListFields.filter((field) => field.type === "Boolean").map((field) => field.name);

writer.writeLine(`if (query?._min)`).block(() => {
writer
.writeLine(`const minResult = {} as Prisma.Result<Prisma.${model.name}Delegate, Q, "aggregate">["_min"];`)
.writeLine(`for (const untypedField of Object.keys(query._min))`)
.block(() => {
writer
.writeLine(`const field = untypedField as keyof (typeof records)[number];`)
.writeLine(`const values = records.map((record) => record[field] as number);`)
.writeLine(`(minResult[field as keyof typeof minResult] as number) = Math.min(...values);`);
})
.writeLine(`result._min = minResult;`);
writer.writeLine(`const minResult = {} as Prisma.Result<Prisma.${model.name}Delegate, Q, "aggregate">["_min"];`);
if (numericFields.length) {
writer
.writeLine(`const numericFields = ${JSON.stringify(numericFields)} as const;`)
.writeLine(`for (const field of numericFields)`)
.block(() => {
writer
.writeLine(`if (!query._min[field]) continue;`)
.writeLine(
`const values = records.map((record) => record[field] as number).filter((value) => value !== undefined);`,
)
.writeLine(`(minResult[field as keyof typeof minResult] as number) = Math.min(...values);`);
});
}
if (dateTimeFields.length) {
writer
.writeLine(`const dateTimeFields = ${JSON.stringify(dateTimeFields)} as const;`)
.writeLine(`for (const field of dateTimeFields)`)
.block(() => {
writer
.writeLine(`if (!query._min[field]) continue;`)
.writeLine(
`const values = records.map((record) => record[field]?.getTime()).filter((value) => value !== undefined);`,
)
.writeLine(`(minResult[field as keyof typeof minResult] as Date) = new Date(Math.min(...values));`);
});
}
if (stringFields.length) {
writer
.writeLine(`const stringFields = ${JSON.stringify(stringFields)} as const;`)
.writeLine(`for (const field of stringFields)`)
.block(() => {
writer
.writeLine(`if (!query._min[field]) continue;`)
.writeLine(
`const values = records.map((record) => record[field] as string).filter((value) => value !== undefined);`,
)
.writeLine(`(minResult[field as keyof typeof minResult] as string) = values.sort()[0];`);
});
}
if (booleanFields.length) {
writer
.writeLine(`const booleanFields = ${JSON.stringify(booleanFields)} as const;`)
.writeLine(`for (const field of booleanFields)`)
.block(() => {
writer
.writeLine(`if (!query._min[field]) continue;`)
.writeLine(
`const values = records.map((record) => record[field] as boolean).filter((value) => value !== undefined);`,
)
.writeLine(`(minResult[field as keyof typeof minResult] as boolean) = values.includes(true);`);
});
}
writer.writeLine(`result._min = minResult;`);
});
}

function addMaxHandling(writer: CodeBlockWriter, model: Model) {
const nonListFields = model.fields.filter(({ isList }) => !isList);

const numericFields = nonListFields
.filter((field) => field.type === "Float" || field.type === "Int" || field.type === "Decimal")
.map((field) => field.name);
const dateTimeFields = nonListFields.filter((field) => field.type === "DateTime").map((field) => field.name);
const stringFields = nonListFields.filter((field) => field.type === "String").map((field) => field.name);
const booleanFields = nonListFields.filter((field) => field.type === "Boolean").map((field) => field.name);

writer.writeLine(`if (query?._max)`).block(() => {
writer
.writeLine(`const maxResult = {} as Prisma.Result<Prisma.${model.name}Delegate, Q, "aggregate">["_max"];`)
.writeLine(`for (const untypedField of Object.keys(query._max))`)
.block(() => {
writer
.writeLine(`const field = untypedField as keyof (typeof records)[number];`)
.writeLine(`const values = records.map((record) => record[field] as number);`)
.writeLine(`(maxResult[field as keyof typeof maxResult] as number) = Math.max(...values);`);
})
.writeLine(`result._max = maxResult;`);
writer.writeLine(`const maxResult = {} as Prisma.Result<Prisma.${model.name}Delegate, Q, "aggregate">["_max"];`);
if (numericFields.length) {
writer
.writeLine(`const numericFields = ${JSON.stringify(numericFields)} as const;`)
.writeLine(`for (const field of numericFields)`)
.block(() => {
writer
.writeLine(`if (!query._max[field]) continue;`)
.writeLine(
`const values = records.map((record) => record[field] as number).filter((value) => value !== undefined);`,
)
.writeLine(`(maxResult[field as keyof typeof maxResult] as number) = Math.max(...values);`);
});
}
if (dateTimeFields.length) {
writer
.writeLine(`const dateTimeFields = ${JSON.stringify(dateTimeFields)} as const;`)
.writeLine(`for (const field of dateTimeFields)`)
.block(() => {
writer
.writeLine(`if (!query._max[field]) continue;`)
.writeLine(
`const values = records.map((record) => record[field]?.getTime()).filter((value) => value !== undefined);`,
)
.writeLine(`(maxResult[field as keyof typeof maxResult] as Date) = new Date(Math.max(...values));`);
});
}
if (stringFields.length) {
writer
.writeLine(`const stringFields = ${JSON.stringify(stringFields)} as const;`)
.writeLine(`for (const field of stringFields)`)
.block(() => {
writer
.writeLine(`if (!query._max[field]) continue;`)
.writeLine(
`const values = records.map((record) => record[field] as string).filter((value) => value !== undefined);`,
)
.writeLine(`(maxResult[field as keyof typeof maxResult] as string) = values.sort().reverse()[0];`);
});
}
if (booleanFields.length) {
writer
.writeLine(`const booleanFields = ${JSON.stringify(booleanFields)} as const;`)
.writeLine(`for (const field of booleanFields)`)
.block(() => {
writer
.writeLine(`if (!query._max[field]) continue;`)
.writeLine(
`const values = records.map((record) => record[field] as boolean).filter((value) => value !== undefined);`,
)
.writeLine(`(maxResult[field as keyof typeof maxResult] as boolean) = values.includes(true);`);
})
.writeLine(`result._max = maxResult;`);
}
writer.writeLine(`result._max = maxResult;`);
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function createDependencies(writer: CodeBlockWriter, model: Model, models: reado
const foreignKeyField = model.fields.find((fkField) => fkField.name === field.relationFromFields?.at(0))!;

writer.writeLine(`if (query.data.${field.name})`).block(() => {
addOneToOneMetaOnFieldRelation(writer, field);
addOneToOneMetaOnFieldRelation(writer, field, models);
});
handleForeignKeyValidation(writer, field, foreignKeyField, models);
});
Expand Down Expand Up @@ -79,7 +79,7 @@ function createDependents(writer: CodeBlockWriter, model: Model, models: readonl
)!;

if (!field.isList) {
addOneToOneMetaOnOtherFieldRelation(writer, field, otherField);
addOneToOneMetaOnOtherFieldRelation(writer, field, otherField, model);
} else {
const dependentModel = models.find(({ name }) => name === field.type)!;
const fks = dependentModel.fields.filter(
Expand All @@ -101,16 +101,19 @@ function applyClausesAndReturnRecords(writer: CodeBlockWriter, model: Model) {
.writeLine(`return recordsWithRelations as Prisma.Result<Prisma.${model.name}Delegate, Q, "create">;`);
}

function addOneToOneMetaOnFieldRelation(writer: CodeBlockWriter, field: Field) {
function addOneToOneMetaOnFieldRelation(writer: CodeBlockWriter, field: Field, models: readonly Model[]) {
const otherModel = models.find(({ name }) => name === field.type)!;
const otherModelKeyPath = JSON.parse(getUniqueIdentifiers(otherModel)[0].keyPath) as string[];

writer
.writeLine(`const fk: Partial<PrismaIDBSchema['${field.type}']['key']> = [];`)
.writeLine(`if (query.data.${field.name}?.create)`)
.block(() => {
writer.writeLine(
`const record = await this.client.${toCamelCase(field.type)}.create({ data: query.data.${field.name}.create }, tx);`,
);
for (let i = 0; i < field.relationToFields!.length; i++) {
writer.writeLine(`fk[${i}] = record.${field.relationToFields?.at(i)}`);
for (let i = 0; i < otherModelKeyPath!.length; i++) {
writer.writeLine(`fk[${i}] = record.${otherModelKeyPath?.at(i)}`);
}
});

Expand All @@ -120,8 +123,8 @@ function addOneToOneMetaOnFieldRelation(writer: CodeBlockWriter, field: Field) {
`const record = await this.client.${toCamelCase(field.type)}.findUniqueOrThrow({ where: query.data.${field.name}.connect }, tx);`,
)
.writeLine(`delete query.data.${field.name}.connect;`);
for (let i = 0; i < field.relationToFields!.length; i++) {
writer.writeLine(`fk[${i}] = record.${field.relationToFields?.at(i)};`);
for (let i = 0; i < otherModelKeyPath!.length; i++) {
writer.writeLine(`fk[${i}] = record.${otherModelKeyPath?.at(i)};`);
}
});

Expand All @@ -132,20 +135,26 @@ function addOneToOneMetaOnFieldRelation(writer: CodeBlockWriter, field: Field) {
.writeLine(`create: query.data.${field.name}.connectOrCreate.create,`)
.writeLine(`update: {},`)
.writeLine(`}, tx);`);
for (let i = 0; i < field.relationToFields!.length; i++) {
writer.writeLine(`fk[${i}] = record.${field.relationToFields?.at(i)};`);
for (let i = 0; i < otherModelKeyPath!.length; i++) {
writer.writeLine(`fk[${i}] = record.${otherModelKeyPath?.at(i)};`);
}
});

writer.writeLine(`const unsafeData = query.data as Record<string, unknown>;`);
field.relationFromFields!.forEach((fromField, idx) => {
writer.writeLine(`unsafeData.${fromField} = fk[${idx}];`);
writer.writeLine(`unsafeData.${fromField} = fk[${otherModelKeyPath.indexOf(field.relationToFields![idx])}];`);
});
writer.writeLine(`delete unsafeData.${field.name};`);
}

function addOneToOneMetaOnOtherFieldRelation(writer: CodeBlockWriter, field: Field, otherField: Field) {
const keyPathMapping = otherField.relationFromFields!.map((field, idx) => `${field}: keyPath[${idx}]`).join(", ");
function addOneToOneMetaOnOtherFieldRelation(writer: CodeBlockWriter, field: Field, otherField: Field, model: Model) {
const modelKeyPath = JSON.parse(getUniqueIdentifiers(model)[0].keyPath) as string[];

const keyPathMapping = otherField
.relationFromFields!.map(
(field, idx) => `${field}: keyPath[${modelKeyPath.indexOf(otherField.relationToFields![idx])}]`,
)
.join(", ");

writer.writeLine(`if (query.data.${field.name}?.create)`).block(() => {
writer
Expand All @@ -159,7 +168,7 @@ function addOneToOneMetaOnOtherFieldRelation(writer: CodeBlockWriter, field: Fie
});
writer.writeLine(`if (query.data.${field.name}?.connect)`).block(() => {
writer.writeLine(
`await this.client.${toCamelCase(field.type)}.update({ where: query.data.${field.name}.connect, data: { ${otherField.relationFromFields?.at(0)}: keyPath[0] } }, tx);`,
`await this.client.${toCamelCase(field.type)}.update({ where: query.data.${field.name}.connect, data: { ${keyPathMapping} } }, tx);`,
);
});
writer.writeLine(`if (query.data.${field.name}?.connectOrCreate)`).block(() => {
Expand Down Expand Up @@ -188,14 +197,19 @@ function addOneToManyRelation(

const modelPk = getUniqueIdentifiers(model)[0];
const modelPkFields = JSON.parse(modelPk.keyPath) as string[];
const fields = `{ ${otherField.relationToFields!.map((field, idx) => `${field}: keyPath[${idx}]`).join(", ")} }`;
const fields = `{ ${modelPkFields.map((field, idx) => `${field}: keyPath[${idx}]`).join(", ")} }`;

let nestedConnectLine = `${otherField.name}: { connect: `;
if (modelPkFields.length === 1) nestedConnectLine += `${fields}`;
else nestedConnectLine += `{ ${modelPk.name}: ${fields} }`;
nestedConnectLine += ` }`;

const nestedDirectLine = otherField.relationFromFields!.map((field, idx) => `${field}: keyPath[${idx}]`).join(", ");
const nestedDirectLine = otherField
.relationFromFields!.map(
(field, idx) =>
`${field}: keyPath[${JSON.parse(getUniqueIdentifiers(model)[0].keyPath).indexOf(otherField.relationToFields?.at(idx))}]`,
)
.join(", ");
const connectQuery = getCreateQuery(nestedConnectLine);

writer.writeLine(`if (query.data?.${field.name}?.create)`).block(() => {
Expand Down
Loading

0 comments on commit e3245fc

Please sign in to comment.