Skip to content

Commit

Permalink
For readthedocs (#373)
Browse files Browse the repository at this point in the history
* feat(gradio): auto build feature store
  • Loading branch information
tpoisonooo authored Aug 28, 2024
1 parent 1fc92d9 commit dcdebc0
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 7 deletions.
43 changes: 38 additions & 5 deletions huixiangdou/gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
from huixiangdou.primitive import Query
from huixiangdou.service import ErrorCode, SerialPipeline, ParallelPipeline, llm_serve, start_llm_server
import json
from datetime import datetime

def ymd():
    """Return today's local date as ``YYYY-MM-DD`` and make sure a directory
    with that name exists in the current working directory.

    Returns:
        str: The formatted date, which is also the (relative) directory name.
    """
    date_string = datetime.now().strftime("%Y-%m-%d")
    try:
        os.makedirs(date_string)
    except FileExistsError:
        # The path is already there (possibly created by a concurrent caller
        # between "check" and "create"); the outcome is the same as the
        # previous explicit existence check, without the race window.
        pass
    return date_string

def parse_args():
"""Parse args."""
Expand All @@ -21,10 +29,10 @@ def parse_args():
type=str,
default='workdir',
help='Working directory.')
parser.add_argument('--pipeline-count', type=int, default=2, help='Support user choosing all pipeline types.')
parser.add_argument('--pipeline-count', type=int, default=1, help='Support user choosing all pipeline types.')
parser.add_argument(
'--config_path',
default='config.ini',
default='config-cpu.ini',
type=str,
help='SerialPipeline configuration path. Default value is config.ini')
parser.add_argument('--standalone',
Expand All @@ -36,9 +44,9 @@ def parse_args():
dest='standalone', # 指定与上面参数相同的目标
help='Do not auto deploy required Hybrid LLM Service.')
parser.add_argument('--placeholder', type=str, default='How to install HuixiangDou ?', help='Placeholder for user query.')
parser.add_argument('--image', action='store_true', default=True, help='')
parser.add_argument('--no-image', action='store_false', dest='image', help='Close some components for readthedocs.')
parser.add_argument('--image', action='store_false', default=True, help='')
parser.add_argument('--theme', type=str, default='soft', help='Gradio theme, default value is `soft`. Open https://www.gradio.app/guides/theming-guide for all themes.')

args = parser.parse_args()
return args

Expand Down Expand Up @@ -93,7 +101,7 @@ async def predict(text:str, image:str):
global paralle_assistant

with open('query.txt', 'a') as f:
f.write(json.dumps({'data': text}))
f.write(json.dumps({'data': text, 'date': ymd()}, ensure_ascii=False))
f.write('\n')

if image is not None:
Expand Down Expand Up @@ -145,8 +153,33 @@ async def predict(text:str, image:str):

yield sentence

def download_and_unzip(main_args):
    """Download the feature-store archive and unpack it under ``feature_local``.

    The archive is fetched with ``wget`` to ``<feature_local>/workdir.zip`` and
    extracted with ``unzip -o`` into ``feature_local``.

    Side effect: ``main_args.work_dir`` is overwritten with
    ``<feature_local>/workdir`` before anything is downloaded.

    Args:
        main_args: Parsed arguments. Reads ``feature_local`` and
            ``feature_url``; assigns ``work_dir``.

    Raises:
        Exception: If the archive does not exist after the download step, or
            the work directory does not exist after extraction.
        OSError: If ``wget`` or ``unzip`` cannot be launched at all.
    """
    import subprocess

    zip_filepath = os.path.join(main_args.feature_local, 'workdir.zip')
    main_args.work_dir = os.path.join(main_args.feature_local, 'workdir')
    logger.info(f'assign {main_args.work_dir} to args.work_dir')

    # Argument lists (no shell) keep URLs containing `&`/`?` and paths with
    # spaces intact, which the previous shell-string form did not.
    subprocess.run(['wget', '-O', zip_filepath, main_args.feature_url],
                   check=False)

    # Success is judged by the presence of the output, not the exit status.
    if not os.path.exists(zip_filepath):
        raise Exception(f'zip filepath {zip_filepath} not exist.')

    subprocess.run(
        ['unzip', '-o', zip_filepath, '-d', main_args.feature_local],
        check=False)
    if not os.path.exists(main_args.work_dir):
        # The original message formatted an undefined name (`zip_dir`), which
        # turned this error path into a NameError.
        raise Exception(f'feature dir {main_args.work_dir} not exist.')

def build_feature_store(main_args):
    """Build the feature store in a child process unless ``workdir`` exists.

    Runs ``python3 -m huixiangdou.service.feature_store --config_path <path>``
    and waits for it to finish. A failed or unlaunchable build is logged and
    does not raise, so startup continues as it did before.

    Args:
        main_args: Parsed arguments. Only ``config_path`` is read.
    """
    import subprocess

    # NOTE(review): the check is against the literal relative path `workdir`,
    # not `main_args.work_dir` - confirm this is intended when a custom
    # working directory is supplied.
    if os.path.exists('workdir'):
        logger.warning('feature_store `workdir` already exist, skip')
        return

    logger.info('start build feature_store..')
    # Argument list instead of a shell string so a config path containing
    # spaces or shell metacharacters is passed through verbatim.
    build_cmd = [
        'python3', '-m', 'huixiangdou.service.feature_store',
        '--config_path', str(main_args.config_path),
    ]
    try:
        completed = subprocess.run(build_cmd, check=False)
    except OSError as err:
        # e.g. `python3` is not on PATH; report instead of crashing startup.
        logger.error(f'failed to launch feature_store build: {err}')
        return
    if completed.returncode != 0:
        # Previously the exit status was discarded and failures were silent.
        logger.error(
            f'feature_store build exited with code {completed.returncode}')

if __name__ == '__main__':
main_args = parse_args()
build_feature_store(main_args)

show_image = True
radio_options = ["chat_with_repo"]

Expand Down
1 change: 1 addition & 0 deletions huixiangdou/primitive/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ def nested_split_markdown(filepath: str,
f'image cannot access. file: {filepath}, image path: {image_path}'
)

# logger.info('{} text_chunks, {} image_chunks'.format(len(text_chunks), len(image_chunks)))
return text_chunks + image_chunks

def clean_md(text: str):
Expand Down
8 changes: 8 additions & 0 deletions huixiangdou/service/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,23 @@ def analyze(self, chunks: List[Chunk]):

text_lens = []
token_lens = []
text_chunk_count = 0
image_chunk_count = 0

if self.embedder is None:
logger.info('self.embedder is None, skip `anaylze_output`')
return
for chunk in chunks:
if chunk.modal == 'image':
image_chunk_count += 1
elif chunk.modal == 'text':
text_chunk_count += 1

content = chunk.content_or_path
text_lens.append(len(content))
token_lens.append(self.embedder.token_length(content))

logger.info('text_chunks {}, image_chunks {}'.format(text_chunk_count, image_chunk_count))
logger.info('text histogram, {}'.format(histogram(text_lens)))
logger.info('token histogram, {}'.format(
histogram(token_lens)))
Expand Down
4 changes: 2 additions & 2 deletions huixiangdou/service/parallel_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def process_coroutine(self, sess: Session) -> Session:
"""Try get reply with text2vec & rerank model."""

# retrieve from knowledge base
sess.parallel_chunks = await asyncio.to_thread(self.retriever.text2vec_retrieve, sess.query)
sess.parallel_chunks = await asyncio.to_thread(self.retriever.text2vec_retrieve, sess.query)
# sess.parallel_chunks = self.retriever.text2vec_retrieve(query=sess.query.text)
return sess

Expand Down Expand Up @@ -250,7 +250,7 @@ def __init__(self, work_dir: str, config_path: str):
config_path (str): The location of the configuration file.
"""
self.llm = ChatClient(config_path=config_path)
self.retriever = CacheRetriever(config_path=config_path).get()
self.retriever = CacheRetriever(config_path=config_path).get(work_dir=work_dir)

self.config_path = config_path
self.config = None
Expand Down

0 comments on commit dcdebc0

Please sign in to comment.