diff --git a/riid/models/neural_nets/basic.py b/riid/models/neural_nets/basic.py index 705b32e..e416947 100644 --- a/riid/models/neural_nets/basic.py +++ b/riid/models/neural_nets/basic.py @@ -110,7 +110,7 @@ def fit(self, ss: SampleSet, batch_size: int = 200, epochs: int = 20, training_dataset = tf.data.Dataset.from_tensor_slices((spectra_tensor, labels_tensor)) training_dataset, validation_dataset = split_dataset( training_dataset, - left_size=validation_split, + right_size=validation_split, shuffle=True ) training_dataset = training_dataset.batch(batch_size=batch_size)