Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
hucorz committed Sep 1, 2024
1 parent 25d8de5 commit e187f16
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/xorbits/_mars/learn/contrib/xgboost/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def worker_envs(self) -> Dict[str, Union[str, int]]:
get environment variables for workers
can be passed in as args or envs
"""
return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}
return {"rabbit_tracker_uri": self.host_ip, "rabbit_tracker_port": self.port}

def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
tree_map: _TreeMap = {}
Expand Down
9 changes: 1 addition & 8 deletions python/xorbits/_mars/learn/contrib/xgboost/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,18 +224,11 @@ def execute(cls, ctx, op: "XGBTrain"):
for arg in rabit_args
]
parsed = {}
args_map_dmlc2rabit = {
"DMLC_TRACKER_URI": "rabit_tracker_uri",
"DMLC_TRACKER_PORT": "rabit_tracker_port",
"DMLC_NUM_WORKER": "rabit_num_worker",
}
if rabit_args:
for arg in rabit_args:
kv = arg.decode().split("=")
if len(kv) == 2:
parsed[args_map_dmlc2rabit[kv[0]]] = (
int(kv[1]) if kv[0] == "DMLC_TRACKER_PORT" else kv[1]
)
parsed[kv[0]] = kv[1]
collective.init(**parsed)
try:
logger.debug(
Expand Down
57 changes: 57 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
from python.xorbits import numpy as np
from python.xorbits import pandas as pd
from python.xorbits import xgboost as xxgb
from python.xorbits._mars.core.entity.objects import ObjectData


X = np.random.rand(100, 10)
X_df = pd.DataFrame(X)
y = np.random.randint(0, 2, 100)

classifier = xxgb.XGBClassifier(verbosity=1, n_estimators=2)

classifier.fit(X_df, y, eval_set=[(X_df, y)])
pred = classifier.predict(X_df)

assert pred.ndim == 1
assert pred.shape[0] == len(X_df)

history = classifier.evals_result()

assert isinstance(history, dict)

assert list(history)[0] == "validation_0"

prob = classifier.predict_proba(X_df)

assert prob.shape[0] == X_df.shape[0]

assert len(pred) == len(y)
assert set(pred.to_numpy().to_numpy()).issubset({0, 1})

# test weight
weights = [
np.random.rand(X_df.shape[0]),
pd.Series(np.random.rand(X_df.shape[0])),
pd.DataFrame(np.random.rand(X_df.shape[0])),
]
y_df = pd.DataFrame(y)
for weight in weights:
classifier.fit(X_df, y_df, sample_weight=weight)
prediction = classifier.predict(X_df)

assert prediction.ndim == 1
assert prediction.shape[0] == len(X_df)

# should raise error if weight.ndim > 1
with pytest.raises(ValueError):
classifier.fit(X_df, y_df, sample_weight=np.random.rand(1, 1))

# test wrong argument
with pytest.raises(TypeError):
classifier.fit(X_df, y, wrong_param=1)

# test wrong attribute
with pytest.raises(AttributeError):
classifier.wrong_attribute()

0 comments on commit e187f16

Please sign in to comment.