Skip to content

Commit

Permalink
reliable communicate dmatrix order in predict
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Nov 3, 2024
1 parent 568b5d1 commit bf7c4ca
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 38 deletions.
39 changes: 11 additions & 28 deletions src/scicloj/ml/xgboost.clj
Original file line number Diff line number Diff line change
Expand Up @@ -447,29 +447,16 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that
(defn- predict
[feature-ds thawed-model {:keys [target-columns target-categorical-maps options]}]
(let [sparse-column-or-nil (:sparse-column options)
_ (def feature-ds feature-ds)



;; prediction-index->document--map
;; (zipmap
;; (range)
;; (-> feature-ds :document distinct sort))


dmatrix (:dmatrix (->dmatrix feature-ds nil sparse-column-or-nil (:n-sparse-columns options)))
dmatrix-context (->dmatrix feature-ds nil sparse-column-or-nil (:n-sparse-columns options))
dmatrix (:dmatrix dmatrix-context)
prediction (.predict ^Booster thawed-model dmatrix)

_ (def prediction prediction)
predict-tensor
(->> prediction
(dtt/->tensor))
target-cname (first target-columns)



prediction-df

(if (multiclass-objective? (options->objective options))
(->
(model/finalize-classification predict-tensor
Expand All @@ -479,21 +466,17 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that
(tech.v3.dataset.modelling/probability-distributions->label-column target-cname)
(ds/update-column (first target-columns)
#(vary-meta % assoc :column-type :prediction)))
(model/finalize-regression predict-tensor target-cname))

;document-ids
;(map prediction-index->document--map (range (ds/row-count prediction-df)))

]
(model/finalize-regression predict-tensor target-cname))]

prediction-df
;(def prediction-df prediction-df)


;; (assoc prediction-df
;; :document document-ids)

))
(if (:dmatrix-order dmatrix-context)
(assoc prediction-df
:document
(-> dmatrix-context
:dmatrix-order
(tc/order-by :row-nr)
:document))
prediction-df)))



Expand Down
1 change: 0 additions & 1 deletion test/scicloj/ml/text_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@

bow-train
(-> ds-train

text/->tfidf
(tc/rename-columns {:meta :label}))

Expand Down
15 changes: 6 additions & 9 deletions test/scicloj/ml/xgboost_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@
:n-sparse-columns n-sparse-columns})


_ (def model model)

test-reviews reviews

Expand All @@ -252,25 +251,23 @@
(tc/select-columns [:label :document])
)

_ (def raw-prediction raw-prediction)
prediction raw-prediction
prediction
(->
raw-prediction
(tc/order-by :document))

trueth
(-> test-reviews
(tc/select-columns [:document :label])
(tc/unique-by [:document :label])
(tc/order-by :document)
:label
)

]

(def prediction prediction)
(def trueth trueth)

(is (< 0.95
(loss/classification-accuracy
(mapv int (:label prediction))
(vec trueth)


)))))
(vec trueth))))))

0 comments on commit bf7c4ca

Please sign in to comment.