Skip to content

Commit

Permalink
Added swap model method
Browse files Browse the repository at this point in the history
  • Loading branch information
Lundez committed Jan 4, 2020
1 parent 7a91475 commit 5d57b33
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ enum class PretrainedModels {
}
```
##### GenerationLevel
No pretrained models exist for `CHAR` currently.
No pretrained models exist for `CHAR` currently. `CHAR` is not supported either.
```kotlin
enum class GenerationLevel {
WORD,
Expand Down
3 changes: 3 additions & 0 deletions src/main/kotlin/com/londogard/textgen/LanguageModel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,8 @@ interface LanguageModel {
* @param oneDocumentPerLine if true we'll treat each line as a document (i.e. pad around each line). If false it'll load the full file.
*/
fun createCustomModel(path: String, name: String, oneDocumentPerLine: Boolean = false)

fun changeModelToCustom(path: String)
fun changeModelToPretrained(pretrainedModel: PretrainedModels)
}

21 changes: 8 additions & 13 deletions src/main/kotlin/com/londogard/textgen/LanguageModelImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ class LanguageModelImpl(
override val generationLevel: GenerationLevel,
override val n: Int = 10
) : LanguageModel {
private val model: BackendLM<String>
private val logger = logger().value
private val model: BackendLM<String>

init {
model = NGramWordLM(n)
when (pretrainedModels) {
PretrainedModels.CUSTOM -> Unit
else -> model.loadModel(getResourcePath(pretrainedModels.path))
}
when (generationLevel) {
GenerationLevel.CHAR -> throw NotImplementedError("Word Generation is not implemented yet.")
else -> Unit
}
when (pretrainedModels) {
PretrainedModels.CUSTOM -> Unit
else -> model.loadModel(getResourcePath(pretrainedModels.path))
}
}

private fun getResourcePath(path: String): String = this::class.java.getResource(path).path
Expand All @@ -46,12 +46,7 @@ class LanguageModelImpl(
logger.info("Model saved as $name")
}

companion object {
@JvmStatic
fun main(args: Array<String>) {
val model = LanguageModelImpl(PretrainedModels.CARDS_AGAINST_WHITE, GenerationLevel.WORD, 10)
//model.createCustomModel("/cardsagainst_white.txt", "cardsagainst_white.cbor", false)
println(model.generateText("have a", 150, 0.1))
}
}
override fun changeModelToCustom(path: String) = model.loadModel(path)

override fun changeModelToPretrained(pretrainedModel: PretrainedModels) = model.loadModel(pretrainedModel.path)
}

0 comments on commit 5d57b33

Please sign in to comment.