Skip to content

Commit

Permalink
Fix #557
Browse files Browse the repository at this point in the history
  • Loading branch information
alex999990009 authored and sxhya committed Nov 12, 2024
1 parent 27cc1ba commit be47dc9
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class CoClauseInserter(private val coClause: CoClauseBase) : AbstractCoClauseIns
anchor.parent.addAfter(sampleCoClause, anchor)
moveCaretToEndOffset(editor, anchor.nextSibling)

anchor.parent.addAfter(factory.createWhitespace("\n"), anchor)
anchor.parent.addAfter(factory.createWhitespace("\n "), anchor)
}
}

Expand All @@ -58,7 +58,7 @@ open class ArendFunctionalInserter(private val definition: ArendFunctionDefiniti
val samplePipe = sampleCoClause.findPrevSibling()!!
val anchor = body.coClauseList.lastOrNull() ?: body.lbrace ?: body.cowithKw
val insertedClause = if (anchor != null) body.addAfter(sampleCoClause, anchor) else body.add(sampleCoClause)
body.addBefore(factory.createWhitespace("\n"), insertedClause)
body.addBefore(factory.createWhitespace("\n "), insertedClause)
body.addBefore(samplePipe, insertedClause)

if (insertedClause != null) moveCaretToEndOffset(editor, insertedClause)
Expand All @@ -72,7 +72,7 @@ class FunctionDefinitionInserter(private val functionDefinition: ArendDefFunctio
if (functionBody == null) {
val functionBodySample = factory.createCoClauseInFunction(name).parent as ArendFunctionBody
functionBody = functionDefinition.addAfter(functionBodySample, functionDefinition.children.last()) as ArendFunctionBody
functionDefinition.addBefore(factory.createWhitespace("\n"), functionBody)
functionDefinition.addBefore(factory.createWhitespace("\n "), functionBody)
moveCaretToEndOffset(editor, functionBody.lastChild)
} else super.insertFirstCoClause(name, factory, editor)
}
Expand All @@ -85,7 +85,7 @@ class ArendInstanceInserter(private val instance: ArendDefInstance) : ArendFunct
val instanceBodySample = factory.createCoClause(name).parent as ArendFunctionBody
val anchor = instance.returnExpr ?: instance.defIdentifier
instanceBody = instance.addAfter(instanceBodySample, anchor) as ArendFunctionBody
instance.addBefore(factory.createWhitespace("\n"), instanceBody)
instance.addBefore(factory.createWhitespace("\n "), instanceBody)
moveCaretToEndOffset(editor, instanceBody.lastChild)
} else super.insertFirstCoClause(name, factory, editor)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
package org.arend.quickfix.implementCoClause

import com.google.common.collect.Sets.combinations
import com.intellij.codeInsight.intention.IntentionAction
import com.intellij.icons.AllIcons
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.components.service
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.project.Project
import com.intellij.openapi.ui.popup.JBPopupFactory
import com.intellij.openapi.ui.popup.PopupStep
import com.intellij.openapi.ui.popup.util.BaseListPopupStep
import com.intellij.openapi.util.Iconable
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiFile
import com.intellij.psi.SmartPsiElementPointer
import org.arend.naming.reference.LocatedReferable
import org.arend.naming.reference.Referable
import org.arend.psi.ArendFile
import org.arend.psi.ArendPsiFactory
import org.arend.psi.ext.*
import org.arend.psi.findPrevSibling
import org.arend.refactoring.moveCaretToEndOffset
import org.arend.settings.ArendProjectStatistics
import org.arend.term.abs.Abstract
import org.arend.util.ArendBundle
import kotlin.math.min

open class ImplementFieldsQuickFix(private val instanceRef: SmartPsiElementPointer<PsiElement>,
private val needsBulb: Boolean,
Expand All @@ -30,6 +40,73 @@ open class ImplementFieldsQuickFix(private val instanceRef: SmartPsiElementPoint
override fun isAvailable(project: Project, editor: Editor?, file: PsiFile?): Boolean =
instanceRef.element != null

private fun collectDefaultStatements(defClass: ArendDefClass, defaultFields: MutableSet<ArendClassStat>) {
defaultFields.addAll(defClass.classStatList.filter { it.isDefault })
for (superClass in defClass.superClassReferences) {
(superClass as? ArendDefClass?)?.let { collectDefaultStatements(it, defaultFields) }
}
}

private fun getMinGroup(defaultFields: Set<PsiElement>, rules: Map<PsiElement, List<PsiElement>>): Set<Set<PsiElement>> {
val results = mutableSetOf<Set<PsiElement>>()
for (groupSize in 1..min(DEFAULT_FIELDS_LIMIT, defaultFields.size)) {
if (results.isNotEmpty()) {
break
}
val groups = combinations(defaultFields.toSet(), groupSize)
for (group in groups) {
val derivedFields = group.toMutableSet()
while (derivedFields.size < defaultFields.size) {
var added = false
for ((field, arguments) in rules) {
if (!derivedFields.contains(field) && arguments.all { derivedFields.contains(it) }) {
derivedFields.add(field)
added = true
}
}
if (!added) {
break
}
}
if (derivedFields.size == defaultFields.size) {
results.add(group)
}
}
}
return results
}

private fun getMinDefaultFields(): Pair<Set<Set<PsiElement>>, Set<PsiElement>> {
val defaultStatements = mutableSetOf<ArendClassStat>()
((instanceRef.element as Abstract.ClassReferenceHolder).classReference as? ArendDefClass?)?.let {
collectDefaultStatements(it, defaultStatements)
}
val defaultFields = defaultStatements.mapNotNull { it.coClause?.longName?.resolve }.toSet()

val rules = mutableMapOf<PsiElement, MutableList<PsiElement>>()
for (defaultStatement in defaultStatements) {
val defaultField = defaultStatement.coClause?.longName?.resolve ?: continue
val arguments = (defaultStatement.coClause?.expr as? ArendNewExpr?)?.argumentAppExpr?.argumentList
?.mapNotNull { (it as? ArendAtomArgument?)?.atomFieldsAcc }?.toMutableList() ?: mutableListOf()
(defaultStatement.coClause?.expr as? ArendNewExpr?)?.argumentAppExpr?.atomFieldsAcc?.let { arguments.add(it) }
for (argument in arguments) {
val defaultArgument = argument.atom.literal?.longName?.resolve ?: continue
if (!defaultFields.contains(defaultArgument)) {
continue
}
rules.getOrPut(defaultField) { mutableListOf() }.add(defaultArgument)
}
}

val defaultDependentFields = defaultFields.filter { rules[it] != null }.toSet()
val result = if (rules.isEmpty()) {
emptySet()
} else {
getMinGroup(defaultDependentFields, rules)
}
return Pair(result, defaultFields)
}

private fun addField(field: Referable, inserter: AbstractCoClauseInserter, editor: Editor?, psiFactory: ArendPsiFactory, needQualifiedName: Boolean = false) {
val coClauses = inserter.coClausesList
val fieldClass = (field as? LocatedReferable)?.locatedReferableParent
Expand All @@ -48,25 +125,169 @@ open class ImplementFieldsQuickFix(private val instanceRef: SmartPsiElementPoint

if (coClause != null) {
val pipeSample = coClause.findPrevSibling()
val whitespace = psiFactory.createWhitespace(" ")
val insertedCoClause = anchor.parent.addAfter(coClause, anchor)
if (insertedCoClause is ArendCoClause && pipeSample != null) {
anchor.parent.addBefore(pipeSample, insertedCoClause)
anchor.parent.addBefore(whitespace, insertedCoClause)
}
if (!caretMoved && editor != null) {
moveCaretToEndOffset(editor, anchor.nextSibling)
caretMoved = true
}
anchor.parent.addAfter(psiFactory.createWhitespace("\n"), anchor)
anchor.parent.addAfter(psiFactory.createWhitespace("\n "), anchor)
}
}
}

private fun getFullClassName(): String {
val classReferable = (instanceRef.element as ArendDefInstance).classReference as PsiLocatedReferable
val name = (classReferable.containingFile as ArendFile).moduleLocation.toString() + "." + classReferable.fullName
return name
}

private fun showFields(
project: Project,
editor: Editor,
variants: MutableList<MutableList<LocatedReferable>>,
allFields: Map<LocatedReferable, Boolean>,
baseFields: List<Pair<LocatedReferable, Boolean>>
) {
if (variants.isEmpty()) {
variants.add(baseFields.map { it.first }.toMutableList())
} else if (variants.size == 1) {
variants.first().addAll(0, baseFields.map { it.first })
}

if (variants.size == 1) {
val psiFactory = ArendPsiFactory(project)
val firstCCInserter = makeFirstCoClauseInserter(instanceRef.element) ?: return
WriteCommandAction.runWriteCommandAction(editor.project) {
for (field in variants.first()) {
addField(field, firstCCInserter, editor, psiFactory, allFields[field]!!)
}
}
return
}

val className = getFullClassName()
val defaultArguments = project.service<ArendProjectStatistics>().state.implementFieldsStatistics[className]
var suggestDefaultOption = true
var matchedList: List<LocatedReferable>? = null
if (defaultArguments != null && variants[0].size == defaultArguments.size) {
val sortedDefaultArguments = defaultArguments.sorted()
for (variant in variants) {
val sortedVariant = variant.map { it.textRepresentation() }.sorted()
if (sortedVariant == sortedDefaultArguments) {
matchedList = variant
}
}
if (matchedList == null) {
suggestDefaultOption = false
}
} else {
suggestDefaultOption = false
}

if (suggestDefaultOption) {
val defaultOption = ArendBundle.message("arend.clause.implementMissing.default.option")
val anotherOption = ArendBundle.message("arend.clause.implementMissing.another.option")
val defaultOptionStep = object : BaseListPopupStep<String>(ArendBundle.message("arend.clause.implementMissing.question"), listOf(defaultOption, anotherOption)) {
override fun onChosen(option: String?, finalChoice: Boolean): PopupStep<*>? {
if (option == defaultOption) {
printFields(project, editor, baseFields, matchedList!!, allFields)
} else if (option == anotherOption) {
createListOfVariants(project, editor, allFields, baseFields, variants)
}
return FINAL_CHOICE
}
}

val popup = JBPopupFactory.getInstance().createListPopup(defaultOptionStep)
popup.showInBestPositionFor(editor)
} else {
createListOfVariants(project, editor, allFields, baseFields, variants)
}
}

private fun createListOfVariants(
project: Project,
editor: Editor,
allFields: Map<LocatedReferable, Boolean>,
baseFields: List<Pair<LocatedReferable, Boolean>>,
variants: List<List<LocatedReferable>>
) {
val fieldsToImplementStep = object : BaseListPopupStep<List<LocatedReferable>>(ArendBundle.message("arend.clause.implementMissing.question"), variants) {
override fun onChosen(extraFields: List<LocatedReferable>?, finalChoice: Boolean): PopupStep<*>? {
if (extraFields != null) {
val name = getFullClassName()
project.service<ArendProjectStatistics>().state.implementFieldsStatistics[name] = extraFields.map { it.textRepresentation() }
printFields(project, editor, baseFields, extraFields, allFields)
}
return FINAL_CHOICE
}
}

val popup = JBPopupFactory.getInstance().createListPopup(fieldsToImplementStep)
popup.showInBestPositionFor(editor)
}

private fun printFields(
project: Project,
editor: Editor,
baseFields: List<Pair<LocatedReferable, Boolean>>,
extraFields: List<LocatedReferable>,
allFields: Map<LocatedReferable, Boolean>
) {
val psiFactory = ArendPsiFactory(project)
val firstCCInserter = makeFirstCoClauseInserter(instanceRef.element) ?: return
val fields = baseFields.map { it.first } + extraFields
WriteCommandAction.runWriteCommandAction(editor.project) {
for (field in fields) {
addField(field, firstCCInserter, editor, psiFactory, allFields[field]!!)
}
}
}

override fun invoke(project: Project, editor: Editor?, file: PsiFile?) {
editor ?: return
val instance = instanceRef.element ?: return
val psiFactory = ArendPsiFactory(project)
val firstCCInserter = makeFirstCoClauseInserter(instance) ?: return
for (f in fieldsToImplement) addField(f.first, firstCCInserter, editor, psiFactory, f.second)

val (groups, defaultFields) = if (instance is Abstract.ClassReferenceHolder) {
getMinDefaultFields()
} else {
Pair(emptySet(), emptySet())
}

val allFields = fieldsToImplement.toMap()
val baseFields = fieldsToImplement.filter { !defaultFields.contains(it.first.underlyingReferable as? PsiElement?) }
val extraFields = fieldsToImplement.filter { defaultFields.contains(it.first.underlyingReferable as? PsiElement?) }
.associateBy { (referable, _) -> defaultFields.find { referable.underlyingReferable == it }!! }

var variants = mutableListOf<MutableList<LocatedReferable>>()
for (group in groups) {
val variant = mutableListOf<LocatedReferable>()
for (element in group) {
extraFields[element]?.first?.let { variant.add(it) }
}
variants.add(variant)
}

if (variants.size >= 1) {
val minSize = variants.minBy { it.size }.size
variants = if (minSize == 0) {
mutableListOf()
} else {
variants.filter { it.size == minSize }.toMutableList()
}
}

showFields(project, editor, variants, allFields, baseFields)
}

override fun getIcon(flags: Int) = if (needsBulb) AllIcons.Actions.IntentionBulb else null

companion object {
const val DEFAULT_FIELDS_LIMIT = 16
}
}
11 changes: 11 additions & 0 deletions src/main/kotlin/org/arend/settings/ArendProjectStatistics.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.arend.settings

import com.intellij.openapi.components.*

@Service(Service.Level.PROJECT)
@State(name = "ArendStatistics", storages = [Storage(StoragePathMacros.WORKSPACE_FILE)])
class ArendProjectStatistics : SimplePersistentStateComponent<ArendProjectStatisticsState>(ArendProjectStatisticsState())

class ArendProjectStatisticsState : BaseState() {
var implementFieldsStatistics by map<String, List<String>>()
}
3 changes: 3 additions & 0 deletions src/main/resources/messages/ArendBundle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ arend.clause.removeRedundant=Remove redundant clause
arend.clause.removeRedundantRHS=Remove redundant clause's right-hand side
arend.coClause.implementMissing=Implement missing fields
arend.clause.implementMissing.default.option=Select default fields
arend.clause.implementMissing.another.option=Select a list of possible field options
arend.clause.implementMissing.question=Choose Which Default Fields To Implement
arend.coClause.implementParentFields=Implement fields of {0}
arend.coClause.removeRedundant=Remove redundant coclause
arend.coClause.replaceWithEmptyImplementation=Replace {?} with empty implementation of the class
Expand Down
44 changes: 44 additions & 0 deletions src/test/kotlin/org/arend/quickfix/ImplementFieldsQuickFixTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -287,4 +287,48 @@ class ImplementFieldsQuickFixTest : QuickFixTestBase() {
\record R (x : Nat) \record S (r : Nat -> R) \func foo : S \cowith | r n : R \cowith
| x => {?}{-caret-}
""")

fun `test default fields instance`() = simpleQuickFixTest(implement,
"""
\class C {
| field : Nat -> Nat
| defaultField : Nat -> Nat
\default defaultField => field
}
\instance I{-caret-} : C
""", """
\class C {
| field : Nat -> Nat
| defaultField : Nat -> Nat
\default defaultField => field
}
\instance I : C
| field => {?}
""")

fun `test default fields func`() = simpleQuickFixTest(implement,
"""
\class C {
| field : Nat -> Nat
| defaultField : Nat -> Nat
\default defaultField => field
}
\func {-caret-}F : C \cowith
""", """
\class C {
| field : Nat -> Nat
| defaultField : Nat -> Nat
\default defaultField => field
}
\func F : C \cowith
| field => {?}
""")
}

0 comments on commit be47dc9

Please sign in to comment.