diff --git a/python/xorbits/_mars/learn/contrib/xgboost/tracker.py b/python/xorbits/_mars/learn/contrib/xgboost/tracker.py
index aeaca07cb..b2f02c078 100644
--- a/python/xorbits/_mars/learn/contrib/xgboost/tracker.py
+++ b/python/xorbits/_mars/learn/contrib/xgboost/tracker.py
@@ -241,7 +241,7 @@ def worker_envs(self) -> Dict[str, Union[str, int]]:
         get environment variables for workers
         can be passed in as args or envs
         """
-        return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}
+        return {"rabit_tracker_uri": self.host_ip, "rabit_tracker_port": self.port}
 
     def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
         tree_map: _TreeMap = {}
diff --git a/python/xorbits/_mars/learn/contrib/xgboost/train.py b/python/xorbits/_mars/learn/contrib/xgboost/train.py
index 6ea49de4d..5ca4476db 100644
--- a/python/xorbits/_mars/learn/contrib/xgboost/train.py
+++ b/python/xorbits/_mars/learn/contrib/xgboost/train.py
@@ -224,18 +224,15 @@ def execute(cls, ctx, op: "XGBTrain"):
                 for arg in rabit_args
             ]
             parsed = {}
-            args_map_dmlc2rabit = {
-                "DMLC_TRACKER_URI": "rabit_tracker_uri",
-                "DMLC_TRACKER_PORT": "rabit_tracker_port",
-                "DMLC_NUM_WORKER": "rabit_num_worker",
-            }
             if rabit_args:
                 for arg in rabit_args:
                     kv = arg.decode().split("=")
                     if len(kv) == 2:
-                        parsed[args_map_dmlc2rabit[kv[0]]] = (
-                            int(kv[1]) if kv[0] == "DMLC_TRACKER_PORT" else kv[1]
-                        )
+                        # The tracker now emits rabit_* keys directly, so no
+                        # name mapping is needed; the port is still cast to int
+                        # exactly as the removed DMLC_TRACKER_PORT branch did.
+                        parsed[kv[0]] = (
+                            int(kv[1]) if kv[0] == "rabit_tracker_port" else kv[1]
+                        )
             collective.init(**parsed)
             try:
                 logger.debug(
diff --git a/test.py b/test.py
new file mode 100644
index 000000000..85d49e5c2
--- /dev/null
+++ b/test.py
@@ -0,0 +1,57 @@
+import pytest
+from python.xorbits import numpy as np
+from python.xorbits import pandas as pd
+from python.xorbits import xgboost as xxgb
+from python.xorbits._mars.core.entity.objects import ObjectData
+
+
+X = np.random.rand(100, 10)
+X_df = pd.DataFrame(X)
+y = np.random.randint(0, 2, 100)
+
+classifier = xxgb.XGBClassifier(verbosity=1, n_estimators=2)
+
+classifier.fit(X_df, y, eval_set=[(X_df, y)])
+pred = classifier.predict(X_df)
+
+assert pred.ndim == 1
+assert pred.shape[0] == len(X_df)
+
+history = classifier.evals_result()
+
+assert isinstance(history, dict)
+
+assert list(history)[0] == "validation_0"
+
+prob = classifier.predict_proba(X_df)
+
+assert prob.shape[0] == X_df.shape[0]
+
+assert len(pred) == len(y)
+assert set(pred.to_numpy().to_numpy()).issubset({0, 1})
+
+# test weight
+weights = [
+    np.random.rand(X_df.shape[0]),
+    pd.Series(np.random.rand(X_df.shape[0])),
+    pd.DataFrame(np.random.rand(X_df.shape[0])),
+]
+y_df = pd.DataFrame(y)
+for weight in weights:
+    classifier.fit(X_df, y_df, sample_weight=weight)
+    prediction = classifier.predict(X_df)
+
+    assert prediction.ndim == 1
+    assert prediction.shape[0] == len(X_df)
+
+# should raise error if weight.ndim > 1
+with pytest.raises(ValueError):
+    classifier.fit(X_df, y_df, sample_weight=np.random.rand(1, 1))
+
+# test wrong argument
+with pytest.raises(TypeError):
+    classifier.fit(X_df, y, wrong_param=1)
+
+# test wrong attribute
+with pytest.raises(AttributeError):
+    classifier.wrong_attribute()