Skip to content

Commit

Permalink
add ray unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
Cathy0908 committed Dec 23, 2024
1 parent bba1f38 commit 4f65b09
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 1 deletion.
2 changes: 1 addition & 1 deletion data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def init_setup_from_cfg(cfg: Namespace):
cfg.np = sys_cpu_count
logger.warning(
f'Number of processes `np` is not set, '
f'Set it to cpu count [{sys_cpu_count}] as default value.')
f'set it to cpu count [{sys_cpu_count}] as default value.')
if cfg.np > sys_cpu_count:
logger.warning(f'Number of processes `np` is set as [{cfg.np}], which '
f'is larger than the cpu count [{sys_cpu_count}]. Due '
Expand Down
97 changes: 97 additions & 0 deletions tests/tools/test_process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,102 @@ def test_status_code_1(self):
self.assertFalse(osp.exists(tmp_out_path))


class ProcessDataRayTest(DataJuicerTestCaseBase):

def setUp(self):
super().setUp()

self._auto_create_ray_cluster()

self.tmp_dir = tempfile.TemporaryDirectory().name
if not osp.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

def _auto_create_ray_cluster(self):
if not subprocess.call('ray status', shell=True):
# ray cluster already exists, return
self.tmp_ray_cluster = False
return

self.tmp_ray_cluster = True
head_port = '6379'
head_addr = '127.0.0.1'
rank = int(os.environ.get('RANK', 0))

if rank == 0:
cmd = f"ray start --head --port={head_port} --node-ip-address={head_addr}"
else:
cmd = f"ray start --address={head_addr}:{head_port}"

print(f"current rank: {rank}; execute cmd: {cmd}")

result = subprocess.call(cmd, shell=True)
if result != 0:
raise subprocess.CalledProcessError(result, cmd)

def _close_ray_cluster(self):
subprocess.call('ray stop', shell=True)

def tearDown(self):
super().tearDown()

if osp.exists(self.tmp_dir):
shutil.rmtree(self.tmp_dir)

import ray
ray.shutdown()

if self.tmp_ray_cluster:
self._close_ray_cluster()

def test_ray_image(self):
tmp_yaml_file = osp.join(self.tmp_dir, 'config_0.yaml')
tmp_out_path = osp.join(self.tmp_dir, 'output_0.json')
text_keys = 'text'

data_path = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))),
'demos', 'data', 'demo-dataset-images.jsonl')
yaml_config = {
'dataset_path': data_path,
'executor_type': 'ray',
'ray_address': 'auto',
'text_keys': text_keys,
'image_key': 'images',
'export_path': tmp_out_path,
'process': [
{
'image_nsfw_filter': {
'hf_nsfw_model': 'Falconsai/nsfw_image_detection',
'trust_remote_code': True,
'score_threshold': 0.5,
'any_or_all': 'any',
'mem_required': '8GB'
},
'image_aspect_ratio_filter':{
'min_ratio': 0.5,
'max_ratio': 2.0
}
}
]
}

with open(tmp_yaml_file, 'w') as file:
yaml.dump(yaml_config, file)

status_code = subprocess.call(
f'python tools/process_data.py --config {tmp_yaml_file}', shell=True)

self.assertEqual(status_code, 0)
self.assertTrue(osp.exists(tmp_out_path))

import ray
res_ds = ray.data.read_json(tmp_out_path)
res_ds = res_ds.to_pandas().to_dict(orient='records')

self.assertEqual(len(res_ds), 3)
for item in res_ds:
self.assertIn('aspect_ratios', item['__dj__stats__'])


if __name__ == '__main__':
unittest.main()

0 comments on commit 4f65b09

Please sign in to comment.