diff --git a/build.gradle.kts b/build.gradle.kts index f776b54..1692226 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -21,7 +21,7 @@ plugins { group = "com.londogard" -version = "1.0-beta" +version = "1.0.1-beta" repositories { mavenCentral() diff --git a/src/main/kotlin/com/londogard/textgen/backends/BackendLM.kt b/src/main/kotlin/com/londogard/textgen/backends/BackendLM.kt index d98e369..2164d14 100644 --- a/src/main/kotlin/com/londogard/textgen/backends/BackendLM.kt +++ b/src/main/kotlin/com/londogard/textgen/backends/BackendLM.kt @@ -7,12 +7,13 @@ import java.io.File @ImplicitReflectionSerializer abstract class BackendLM { - protected abstract val mapSerializer: KSerializer, Double>> - private val mapCharSerializer = (Char::class.serializer().list to Double::class.serializer()).map + protected abstract val mapSerializer: KSerializer>> + protected val stringSerializer = String::class.serializer() + protected val doubleSerializer = Double::class.serializer() private val cborSerializer = Cbor.plain private val padStart: Char = '\u0002' private val padEnd: Char = '\u0003' - protected abstract var internalLanguageModel: Map, Double> + protected abstract var internalLanguageModel: Map> // Map> protected abstract val n: Int val padEndList = List(n) { padEnd.toString() } val padStartList = List(n) { padStart.toString() } @@ -26,15 +27,15 @@ abstract class BackendLM { private fun getResource(path: String): InputStream = this::class.java.getResourceAsStream(path) - protected fun serializeMapToFile(name: String, map: Map, Double>): Unit = cborSerializer + protected fun serializeMapToFile(name: String, map: Map>): Unit = cborSerializer .dump(mapSerializer, map) .let { File(name).writeBytes(it) } - protected fun readSerializedMapFromFile(name: String): Map, Double> = File(name) + protected fun readSerializedMapFromFile(name: String): Map> = File(name) .readBytes() .let { cborSerializer.load(mapSerializer, it) } - protected fun readSerializedMapFromResource(name: String): Map, Double> = getResource(name) + protected fun readSerializedMapFromResource(name: String): Map> = getResource(name) .readBytes() .let { cborSerializer.load(mapSerializer, it) } } \ No newline at end of file diff --git a/src/main/kotlin/com/londogard/textgen/backends/NGramWordLM.kt b/src/main/kotlin/com/londogard/textgen/backends/NGramWordLM.kt index bc29979..94a4a93 100644 --- a/src/main/kotlin/com/londogard/textgen/backends/NGramWordLM.kt +++ b/src/main/kotlin/com/londogard/textgen/backends/NGramWordLM.kt @@ -4,16 +4,16 @@ import com.londogard.textgen.NGram import com.londogard.textgen.ngramNormalize import kotlinx.serialization.* import java.io.File +import kotlin.math.min import kotlin.math.pow import kotlin.random.Random @ImplicitReflectionSerializer class NGramWordLM( override val n: Int, - override var internalLanguageModel: Map, Double> = emptyMap(), - override val mapSerializer: KSerializer, Double>> = (String::class.serializer().list to Double::class.serializer()).map -) : - BackendLM() { + override var internalLanguageModel: Map> = emptyMap()) : BackendLM() { + private val stringDouble = (stringSerializer to doubleSerializer).map + override val mapSerializer: KSerializer>> = (stringSerializer to stringDouble).map override fun predictNext(input: String, temperature: Double): String = TODO("Implement this, don't forget to not remove \n etc") @@ -52,11 +52,20 @@ class NGramWordLM( val totalCount = internalModel.filterKeys { it.size == 1 }.values.sum() - internalLanguageModel = internalModel + val precomputedModel = internalModel .mapValues { (key, value) -> (value / (internalModel[key.dropLast(1)] ?: totalCount)) } + internalLanguageModel = precomputedModel + .entries + .groupBy( { it.key.dropLast(1).joinToString(" ") } , { it.key.last() to it.value }) + .mapValues { it.value.toMap() } + //internalLanguageModel = internalModel + // .mapValues { (key, value) -> + // (0.4.pow(n - key.size)) * value / (internalModel[key.dropLast(1)] ?: totalCount) + // } + // TODO add Kneser-Ney Smooth //val discountByN = (1..n).map { i -> // val ndValues = modelByN[i]?.filterValues { it in listOf(1.0, 2.0) } ?: emptyMap() @@ -71,17 +80,13 @@ class NGramWordLM( override fun predictNext(input: List, temperature: Double): String { val history = input.takeLast(n - 1) - val options = (n downTo 1) - .asSequence() - .map { i -> - val discount = 0.4.pow(n.toDouble() - i) - - internalLanguageModel - .filterKeys { it.size == i && (it.size == 1 || it.take(i - 1) == history.takeLast(i - 1)) } - .mapValues { it.value * discount }.entries - } - .map { it.sortedByDescending { l -> l.value } } - .flatten() + val keys = (min(input.size, n) downTo 0).map { + input.takeLast(it).joinToString(" ") + } + + val options = keys.asSequence() + .mapNotNull { key -> internalLanguageModel[key]?.entries } + .flatMap { it.sortedByDescending { subEntry -> subEntry.value }.take(10).asSequence() } .take(10) .toList() @@ -93,6 +98,6 @@ class NGramWordLM( selection -= it.value selection > 0 } - .first().key.last() + .first().key } } \ No newline at end of file diff --git a/src/main/resources/models/cardsagainst_black.cbor b/src/main/resources/models/cardsagainst_black.cbor index 83f684c..f13706f 100644 Binary files a/src/main/resources/models/cardsagainst_black.cbor and b/src/main/resources/models/cardsagainst_black.cbor differ diff --git a/src/main/resources/models/cardsagainst_white.cbor b/src/main/resources/models/cardsagainst_white.cbor index d89c2b8..acab7cf 100644 Binary files a/src/main/resources/models/cardsagainst_white.cbor and b/src/main/resources/models/cardsagainst_white.cbor differ diff --git a/src/main/resources/models/shakespeare.cbor b/src/main/resources/models/shakespeare.cbor index 490d000..a006c44 100644 Binary files a/src/main/resources/models/shakespeare.cbor and b/src/main/resources/models/shakespeare.cbor differ