-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathsvg_render.py
141 lines (105 loc) · 4.77 KB
/
svg_render.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# -*- coding: utf-8 -*-
# Author: ximing xing
# Description: the main entry point of this project; dispatches to the rendering pipeline selected in the config.
# Copyright (c) 2023, XiMing Xing.
import os
import sys
from functools import partial
from accelerate.utils import set_seed
import hydra
import omegaconf
# Append the parent of the directory containing this script to the module search path.
# NOTE(review): presumably this is what makes `pytorch_svgrender` importable when the
# script is run directly from a source checkout — confirm against the repository layout.
sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
from pytorch_svgrender.utils import render_batch_wrap, get_seed_range
# Every value accepted for `cfg.x.method`; `main` rejects anything not listed here.
# The order is kept as-is; the trailing comments mirror the task tags used in `main`.
METHODS = [
    'diffvg',             # img2svg
    'live',               # img2svg
    'vectorfusion',       # text2svg
    'clipasso',           # img2sketch
    'clipascene',         # takes cfg.target as input
    'diffsketcher',       # text2sketch
    'stylediffsketcher',  # text2sketch + style transfer
    'clipdraw',           # text2svg
    'styleclipdraw',      # text to stylized svg
    'wordasimage',        # text2font
    'clipfont',           # text and font to font
    'svgdreamer',         # text2svg
]
@hydra.main(version_base=None, config_path="conf", config_name='config')
def main(cfg: omegaconf.DictConfig):
    """Run the SVG rendering pipeline selected by ``cfg.x.method``.

    The project configuration is stored in './conf/config.yaml' and the
    per-method configurations are stored in './conf/x/'.

    Each pipeline module is imported lazily, only when its method is selected.
    For vectorfusion, svgdreamer, clipfont, diffsketcher and stylediffsketcher,
    ``cfg.multirun`` switches from a single run to a batched run through
    ``render_batch_wrap``, which receives the seed range built from
    ``cfg.srange``. The other methods always run once.

    Args:
        cfg: Hydra/OmegaConf configuration. Reads ``x.method``, ``seed``,
            ``srange`` and ``multirun``, plus, depending on the method,
            ``target``, ``prompt``, ``x.word`` and ``x.optim_letter``.

    Raises:
        ValueError: if ``cfg.x.method`` is not one of ``METHODS``.
    """
    flag = cfg.x.method
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would let an unknown method silently fall through every branch below.
    if flag not in METHODS:
        raise ValueError(f"{flag} is not currently supported!")

    # seed prepare
    set_seed(cfg.seed)
    # The seed range is only needed when generating many SVGs at once.
    seed_range = get_seed_range(cfg.srange) if cfg.multirun else None

    # Batch renderer with `cfg` and `seed_range` pre-bound; used only in multirun mode.
    render_batch_fn = partial(render_batch_wrap, cfg=cfg, seed_range=seed_range)

    if flag == "diffvg":  # img2svg
        from pytorch_svgrender.pipelines.DiffVG_pipeline import DiffVGPipeline

        pipe = DiffVGPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "live":  # img2svg
        from pytorch_svgrender.pipelines.LIVE_pipeline import LIVEPipeline

        pipe = LIVEPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "vectorfusion":  # text2svg
        from pytorch_svgrender.pipelines.VectorFusion_pipeline import VectorFusionPipeline

        if not cfg.multirun:
            pipe = VectorFusionPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=VectorFusionPipeline, text_prompt=cfg.prompt)

    elif flag == "svgdreamer":  # text2svg
        from pytorch_svgrender.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline

        if not cfg.multirun:
            pipe = SVGDreamerPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=SVGDreamerPipeline, text_prompt=cfg.prompt, target_file=None)

    elif flag == "wordasimage":  # text2font
        from pytorch_svgrender.pipelines.WordAsImage_pipeline import WordAsImagePipeline

        pipe = WordAsImagePipeline(cfg)
        pipe.painterly_rendering(cfg.x.word, cfg.prompt, cfg.x.optim_letter)

    elif flag == "clipasso":  # img2sketch
        from pytorch_svgrender.pipelines.CLIPasso_pipeline import CLIPassoPipeline

        pipe = CLIPassoPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == 'clipascene':
        from pytorch_svgrender.pipelines.CLIPascene_pipeline import CLIPascenePipeline

        pipe = CLIPascenePipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "clipdraw":  # text2svg
        from pytorch_svgrender.pipelines.CLIPDraw_pipeline import CLIPDrawPipeline

        pipe = CLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt)

    elif flag == "clipfont":  # text and font to font
        from pytorch_svgrender.pipelines.CLIPFont_pipeline import CLIPFontPipeline

        if not cfg.multirun:
            pipe = CLIPFontPipeline(cfg)
            pipe.painterly_rendering(svg_path=cfg.target, prompt=cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=CLIPFontPipeline, svg_path=cfg.target, prompt=cfg.prompt)

    elif flag == "styleclipdraw":  # text to stylized svg
        from pytorch_svgrender.pipelines.StyleCLIPDraw_pipeline import StyleCLIPDrawPipeline

        pipe = StyleCLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)

    elif flag == "diffsketcher":  # text2sketch
        from pytorch_svgrender.pipelines.DiffSketcher_pipeline import DiffSketcherPipeline

        if not cfg.multirun:
            pipe = DiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=DiffSketcherPipeline, prompt=cfg.prompt)

    elif flag == "stylediffsketcher":  # text2sketch + style transfer
        from pytorch_svgrender.pipelines.DiffSketcher_stylized_pipeline import StylizedDiffSketcherPipeline

        if not cfg.multirun:
            pipe = StylizedDiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)
        else:  # generate many SVG at once
            # The style image comes from `cfg.target`, matching the single-run branch
            # above (previously this read a separate `cfg.style_file` key).
            render_batch_fn(pipeline=StylizedDiffSketcherPipeline, prompt=cfg.prompt, style_fpath=cfg.target)
if __name__ == "__main__":
    # Called with no arguments: the `@hydra.main` decorator on `main` builds `cfg`.
    main()