diff --git a/README.md b/README.md index a1a04e4..035069f 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ enum class PretrainedModels { } ``` ##### GenerationLevel -No pretrained models exist for `CHAR` currently. +No pretrained models exist for `CHAR` currently. `CHAR` is not supported either. ```kotlin enum class GenerationLevel { WORD, diff --git a/src/main/kotlin/com/londogard/textgen/LanguageModel.kt b/src/main/kotlin/com/londogard/textgen/LanguageModel.kt index 1957817..b3cdef3 100644 --- a/src/main/kotlin/com/londogard/textgen/LanguageModel.kt +++ b/src/main/kotlin/com/londogard/textgen/LanguageModel.kt @@ -27,5 +27,8 @@ interface LanguageModel { * @param oneDocumentPerLine if true we'll treat each line as a document (i.e. pad around each line). If false it'll load the full file. */ fun createCustomModel(path: String, name: String, oneDocumentPerLine: Boolean = false) + + fun changeModelToCustom(path: String) + fun changeModelToPretrained(pretrainedModel: PretrainedModels) } diff --git a/src/main/kotlin/com/londogard/textgen/LanguageModelImpl.kt b/src/main/kotlin/com/londogard/textgen/LanguageModelImpl.kt index 34c5d39..8910b26 100644 --- a/src/main/kotlin/com/londogard/textgen/LanguageModelImpl.kt +++ b/src/main/kotlin/com/londogard/textgen/LanguageModelImpl.kt @@ -11,19 +11,19 @@ class LanguageModelImpl( override val generationLevel: GenerationLevel, override val n: Int = 10 ) : LanguageModel { - private val model: BackendLM private val logger = logger().value + private val model: BackendLM init { model = NGramWordLM(n) - when (pretrainedModels) { - PretrainedModels.CUSTOM -> Unit - else -> model.loadModel(getResourcePath(pretrainedModels.path)) - } when (generationLevel) { GenerationLevel.CHAR -> throw NotImplementedError("Word Generation is not implemented yet.") else -> Unit } + when (pretrainedModels) { + PretrainedModels.CUSTOM -> Unit + else -> model.loadModel(getResourcePath(pretrainedModels.path)) + } } private fun getResourcePath(path: String): String = this::class.java.getResource(path).path @@ -46,12 +46,7 @@ class LanguageModelImpl( logger.info("Model saved as $name") } - companion object { - @JvmStatic - fun main(args: Array) { - val model = LanguageModelImpl(PretrainedModels.CARDS_AGAINST_WHITE, GenerationLevel.WORD, 10) - //model.createCustomModel("/cardsagainst_white.txt", "cardsagainst_white.cbor", false) - println(model.generateText("have a", 150, 0.1)) - } - } + override fun changeModelToCustom(path: String) = model.loadModel(path) + + override fun changeModelToPretrained(pretrainedModel: PretrainedModels) = model.loadModel(pretrainedModel.path) } \ No newline at end of file