It would be useful to also see how each SSL model performs compared with its purely supervised backbone trained on the labeled data alone.
For example, TSVM vs. a plain SVM:
import numpy as np
from LAMDA_SSL.Dataset.Tabular.BreastCancer import BreastCancer

dataset = BreastCancer(labeled_size=0.1, stratified=True, shuffle=True)
labeled_X = dataset.labeled_X
labeled_y = dataset.labeled_y
unlabeled_X = dataset.unlabeled_X
unlabeled_y = dataset.unlabeled_y

from sklearn import preprocessing

pre_transform = preprocessing.StandardScaler()
pre_transform.fit(np.vstack([labeled_X, unlabeled_X]))
labeled_X=pre_transform.transform(labeled_X)
unlabeled_X=pre_transform.transform(unlabeled_X)
from LAMDA_SSL.Algorithm.Classification.TSVM import TSVM

# I tried a range of Cl and Cu, starting from Cl=15 and Cu=0.0001 and then gradually
# increasing Cu and decreasing Cl. It didn't seem to make a difference?
model = TSVM(Cl=1, Cu=1, kernel="linear")
model.fit(X=labeled_X, y=labeled_y, unlabeled_X=unlabeled_X)
pred_y=model.predict()
from LAMDA_SSL.Evaluation.Classifier.Accuracy import Accuracy

score = Accuracy().scoring(unlabeled_y, pred_y)
print(f"SSL TSVM score: {score}")
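#> SSL TSVM score: 0.9609375

# Compare with a pure SVM.
# Note: svm.SVC() defaults to kernel="rbf", whereas the TSVM above uses kernel="linear".
from sklearn import svm

model_sl = svm.SVC()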
model_sl.fit(labeled_X, labeled_y)
pred_sl=model_sl.predict(unlabeled_X)
score_sl=Accuracy().scoring(unlabeled_y, pred_sl)
print(f"SL SVM score: {score_sl}")
#> SL SVM score: 0.955078125
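A like-for-like comparison would keep the kernel identical on both sides and repeat the experiment over several random labeled/unlabeled splits, since a single 10% labeled split can move accuracy by a few tenths of a percent either way. The sketch below shows one way to do that. It is not LAMDA_SSL code: it assumes scikit-learn's `train_test_split` for the splits (instead of the `BreastCancer` dataset class) and, because it must run without LAMDA_SSL installed, it uses scikit-learn's `SelfTrainingClassifier` as a stand-in SSL method. The helper names `paired_comparison`, `self_training_svm` and `supervised_svm` are hypothetical.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.semi_supervised import SelfTrainingClassifier
from sklearn.svm import SVC


def paired_comparison(ssl_predict, sl_predict, X, y, labeled_size=0.1, seeds=range(5)):
    """Hypothetical helper: per-seed (ssl_acc, sl_acc) on the unlabeled split.

    ssl_predict(X_l, y_l, X_u) and sl_predict(X_l, y_l, X_u) must each return
    predicted labels for X_u. For a given seed both see exactly the same split.
    """
    results = []
    for seed in seeds:
        X_l, X_u, y_l, y_u = train_test_split(
            X, y, train_size=labeled_size, stratify=y, random_state=seed
        )
        # Fit the scaler on labeled + unlabeled features, as in the issue's code.
        scaler = StandardScaler().fit(np.vstack([X_l, X_u]))
        X_l, X_u = scaler.transform(X_l), scaler.transform(X_u)
        ssl_acc = np.mean(ssl_predict(X_l, y_l, X_u) == y_u)
        sl_acc = np.mean(sl_predict(X_l, y_l, X_u) == y_u)
        results.append((ssl_acc, sl_acc))
    return np.array(results)


def self_training_svm(X_l, y_l, X_u):
    # Stand-in SSL method (not TSVM): scikit-learn self-training around a linear SVC.
    X_all = np.vstack([X_l, X_u])
    y_all = np.concatenate([y_l, -np.ones(len(X_u), dtype=int)])  # -1 marks unlabeled
    model = SelfTrainingClassifier(SVC(kernel="linear", probability=True, random_state=0))
    model.fit(X_all, y_all)
    return model.predict(X_u)


def supervised_svm(X_l, y_l, X_u):
    # Purely supervised baseline with the SAME kernel, trained on labeled data only.
    return SVC(kernel="linear").fit(X_l, y_l).predict(X_u)


if __name__ == "__main__":
    X, y = load_breast_cancer(return_X_y=True)
    seeds = range(5)
    res = paired_comparison(self_training_svm, supervised_svm, X, y, seeds=seeds)
    for seed, (ssl_acc, sl_acc) in zip(seeds, res):
        print(f"seed {seed}: SSL {ssl_acc:.4f}  SL {sl_acc:.4f}")
    diff = res[:, 0] - res[:, 1]
    print(f"mean SSL - SL accuracy: {diff.mean():+.4f} (std {diff.std():.4f})")
```

To compare the TSVM from the issue instead, the same `fit(X=..., y=..., unlabeled_X=...)` and `predict()` calls used above could be wrapped in a function with the `(X_l, y_l, X_u)` signature and passed as `ssl_predict`. Reporting the per-seed difference, rather than two independent averages, keeps each SSL score paired with the baseline trained on the same labeled subset.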