diff --git a/test/scicloj/ml/xgboost_test.clj b/test/scicloj/ml/xgboost_test.clj index e196b42..6935c72 100644 --- a/test/scicloj/ml/xgboost_test.clj +++ b/test/scicloj/ml/xgboost_test.clj @@ -228,46 +228,71 @@ first (tc/drop-missing) (text/->tfidf) - (tc/rename-columns {:meta :label}) - (ds-mod/set-inference-target [:label])) + (tc/rename-columns {:meta :label})) + rnd-documents (shuffle (range 1000)) + train-documents (into #{} (take 800 rnd-documents)) + test-documents (into #{} (take-last 200 rnd-documents)) - - n-sparse-columns (inc (apply max (reviews :token-idx))) - model - (ml/train reviews {:model-type :xgboost/classification - :sparse-column :tfidf - :seed 123 - :num-class 5 - :n-sparse-columns n-sparse-columns}) - + train-reviews + (-> reviews + (tc/select-rows (fn [row] (contains? train-documents (:document row)))) + (ds-mod/set-inference-target :label)) + trueth-train + (-> train-reviews + (tc/select-columns [:document :label]) + (tc/unique-by [:document :label]) + (tc/order-by :document) + :label) - test-reviews reviews - - raw-prediction - (-> - (ml/predict test-reviews model) - (tc/select-columns [:label :document]) - ) - - prediction - (-> - raw-prediction - (tc/order-by :document)) + test-reviews + (-> reviews + (tc/select-rows (fn [row] (contains? test-documents (:document row))))) - trueth + trueth-test (-> test-reviews (tc/select-columns [:document :label]) (tc/unique-by [:document :label]) (tc/order-by :document) - :label - ) + :label) + + test-review-clean + (-> test-reviews + (tc/drop-columns [:label])) + + n-sparse-columns (inc (apply max (reviews :token-idx))) + model + (ml/train train-reviews {:model-type :xgboost/classification + :sparse-column :tfidf + :seed 123 + :num-class 5 + :n-sparse-columns n-sparse-columns}) + + + prediction-train + (-> + (ml/predict train-reviews model) + (tc/select-columns [:label :document]) + (tc/order-by :document)) + + prediction-test + (-> + (ml/predict test-review-clean model) + (tc/select-columns [:label :document]) + (tc/order-by :document)) + ] (is (< 0.95 (loss/classification-accuracy - (mapv int (:label prediction)) - (vec trueth)))))) + (mapv int (:label prediction-train)) + (vec trueth-train)))) + + (is (< 0.55 + (loss/classification-accuracy + (mapv int (:label prediction-test)) + (vec trueth-test)))) + ))