Skip to content

Commit

Permalink
update unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
Cathy0908 committed Dec 25, 2024
1 parent 0a2ddd6 commit 73c7316
Showing 1 changed file with 5 additions and 35 deletions.
40 changes: 5 additions & 35 deletions tests/tools/test_process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import subprocess
import tempfile
import unittest
import uuid
import yaml

from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
Expand Down Expand Up @@ -45,8 +46,7 @@ def setUp(self):
super().setUp()

self.tmp_dir = tempfile.TemporaryDirectory().name
if not osp.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
os.makedirs(self.tmp_dir, exist_ok=True)

def tearDown(self):
super().tearDown()
Expand Down Expand Up @@ -101,36 +101,9 @@ class ProcessDataRayTest(DataJuicerTestCaseBase):
def setUp(self):
super().setUp()

# self._auto_create_ray_cluster()
self.tmp_dir = f'/workspace/tmp/{self.__class__.__name__}'
if not osp.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

def _auto_create_ray_cluster(self):
try:
# ray cluster already exists, return
run_in_subprocess('ray status')
self.tmp_ray_cluster = False
return
except:
pass

self.tmp_ray_cluster = True
head_port = '6379'
head_addr = '127.0.0.1'
rank = int(os.environ.get('RANK', 0))

if rank == 0:
cmd = f"ray start --head --port={head_port} --node-ip-address={head_addr}"
else:
cmd = f"ray start --address={head_addr}:{head_port}"

print(f"current rank: {rank}; execute cmd: {cmd}")

run_in_subprocess(cmd)

def _close_ray_cluster(self):
run_in_subprocess('ray stop')
cur_dir = osp.dirname(osp.abspath(__file__))
self.tmp_dir = osp.join(cur_dir, f'tmp_{uuid.uuid4().hex}')
os.makedirs(self.tmp_dir, exist_ok=True)

def tearDown(self):
super().tearDown()
Expand All @@ -141,9 +114,6 @@ def tearDown(self):
import ray
ray.shutdown()

# if self.tmp_ray_cluster:
# self._close_ray_cluster()

def test_ray_image(self):
tmp_yaml_file = osp.join(self.tmp_dir, 'config_0.yaml')
tmp_out_path = osp.join(self.tmp_dir, 'output_0.json')
Expand Down

0 comments on commit 73c7316

Please sign in to comment.