Skip to content

Commit

Permalink
优化权息数据加载速度
Browse files Browse the repository at this point in the history
  • Loading branch information
fasiondog committed Jan 5, 2024
1 parent 0b915d4 commit 41866b7
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 15 deletions.
20 changes: 6 additions & 14 deletions hikyuu_cpp/hikyuu/StockManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,23 +538,15 @@ void StockManager::loadAllHolidays() {

void StockManager::loadAllStockWeights() {
HKU_INFO("Loading stock weight...");
ThreadPool tg; // 这里不用全局的线程池,可以避免在初始化后立即reload导致过长的等待
std::vector<std::future<void>> task_list;
auto all_stkweight_dict = m_baseInfoDriver->getAllStockWeightList();
std::lock_guard<std::mutex> lock(*m_stockDict_mutex);
for (auto iter = m_stockDict.begin(); iter != m_stockDict.end(); ++iter) {
task_list.push_back(tg.submit([=]() mutable {
auto weight_iter = all_stkweight_dict.find(iter->first);
if (weight_iter != all_stkweight_dict.end()) {
Stock& stock = iter->second;
StockWeightList weightList = m_baseInfoDriver->getStockWeightList(
stock.market(), stock.code(), Datetime::min(), Null<Datetime>());
if (stock.m_data) {
std::lock_guard<std::mutex> lock(stock.m_data->m_weight_mutex);
stock.m_data->m_weightList.swap(weightList);
}
}));
}
// 权息信息如果不等待加载完毕,在数据加载期间进行计算可能导致复权错误,所以这里需要等待
for (auto& task : task_list) {
task.get();
std::lock_guard<std::mutex> lock(stock.m_data->m_weight_mutex);
stock.m_data->m_weightList.swap(weight_iter->second);
}
}
}

Expand Down
5 changes: 5 additions & 0 deletions hikyuu_cpp/hikyuu/data_driver/BaseInfoDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ class HKU_API BaseInfoDriver {
virtual StockWeightList getStockWeightList(const string& market, const string& code,
Datetime start, Datetime end);

virtual unordered_map<string, StockWeightList> getAllStockWeightList() {
unordered_map<string, StockWeightList> ret;
return ret;
}

/**
* 获取当前财务信息
* @param market 市场标识
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,48 @@ StockWeightList MySQLBaseInfoDriver::getStockWeightList(const string &market, co
return result;
}

unordered_map<string, StockWeightList> MySQLBaseInfoDriver::getAllStockWeightList() {
unordered_map<string, StockWeightList> result;
HKU_ASSERT(m_pool);

try {
auto con = m_pool->getConnect();
HKU_CHECK(con, "Failed fetch connect!");

vector<StockWeightTableView> view;
con->batchLoadView(
view,
"SELECT a.id AS id, concat(market.market, stock.code) AS market_code, a.date, "
"a.countAsGift*0.0001 AS countAsGift, a.countForSell*0.0001 AS countForSell, "
"a.priceForSell*0.001 AS priceForSell, a.bonus*0.001,a.countOfIncreasement*0.0001 AS "
"countOfIncreasement, a.totalCount AS totalCount, a.freeCount AS freeCount FROM "
"stkweight AS a, stock, market WHERE a.stockid=stock.stockid AND "
"market.marketid=stock.marketid ORDER BY a.stockid, a.date");

for (const auto &record : view) {
auto iter = result.find(record.market_code);
if (iter == result.end()) {
auto in_iter = result.insert(std::make_pair(record.market_code, StockWeightList()));
if (in_iter.second) {
iter = in_iter.first;
}
}
iter->second.emplace_back(StockWeight(
Datetime(record.date), record.countAsGift, record.countForSell, record.priceForSell,
record.bonus, record.countOfIncreasement, record.totalCount, record.freeCount));
}

} catch (std::exception &e) {
HKU_FATAL("load StockWeight table failed! {}", e.what());
return result;
} catch (...) {
HKU_FATAL("load StockWeight table failed!");
return result;
}

return result;
}

vector<StockInfo> MySQLBaseInfoDriver::getAllStockInfo() {
vector<StockInfo> result;
HKU_ERROR_IF_RETURN(!m_pool, result, "Connect pool ptr is null!");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class MySQLBaseInfoDriver : public BaseInfoDriver {
virtual Parameter getFinanceInfo(const string& market, const string& code) override;
virtual StockWeightList getStockWeightList(const string& market, const string& code,
Datetime start, Datetime end) override;
virtual unordered_map<string, StockWeightList> getAllStockWeightList() override;
virtual MarketInfo getMarketInfo(const string& market) override;
virtual StockTypeInfo getStockTypeInfo(uint32_t type) override;
virtual StockInfo getStockInfo(string market, const string& code) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,48 @@ StockWeightList SQLiteBaseInfoDriver::getStockWeightList(const string& market, c
return result;
}

unordered_map<string, StockWeightList> SQLiteBaseInfoDriver::getAllStockWeightList() {
unordered_map<string, StockWeightList> result;
HKU_ASSERT(m_pool);

try {
auto con = m_pool->getConnect();
HKU_CHECK(con, "Failed fetch connect!");

vector<StockWeightTableView> view;
con->batchLoadView(
view,
"SELECT a.id AS id, (market.market || stock.code) AS market_code, a.date, "
"a.countAsGift*0.0001 AS countAsGift, a.countForSell*0.0001 AS countForSell, "
"a.priceForSell*0.001 AS priceForSell, a.bonus*0.001,a.countOfIncreasement*0.0001 AS "
"countOfIncreasement, a.totalCount AS totalCount, a.freeCount AS freeCount FROM "
"stkweight AS a, stock, market WHERE a.stockid=stock.stockid AND "
"market.marketid=stock.marketid ORDER BY a.stockid, a.date");

for (const auto& record : view) {
auto iter = result.find(record.market_code);
if (iter == result.end()) {
auto in_iter = result.insert(std::make_pair(record.market_code, StockWeightList()));
if (in_iter.second) {
iter = in_iter.first;
}
}
iter->second.emplace_back(StockWeight(
Datetime(record.date), record.countAsGift, record.countForSell, record.priceForSell,
record.bonus, record.countOfIncreasement, record.totalCount, record.freeCount));
}

} catch (std::exception& e) {
HKU_FATAL("load StockWeight table failed! {}", e.what());
return result;
} catch (...) {
HKU_FATAL("load StockWeight table failed!");
return result;
}

return result;
}

Parameter SQLiteBaseInfoDriver ::getFinanceInfo(const string& market, const string& code) {
Parameter result;
HKU_IF_RETURN(!m_pool, result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ class SQLiteBaseInfoDriver : public BaseInfoDriver {
virtual Parameter getFinanceInfo(const string& market, const string& code) override;
virtual StockWeightList getStockWeightList(const string& market, const string& code,
Datetime start, Datetime end) override;
virtual unordered_map<string, StockWeightList> getAllStockWeightList() override;
virtual MarketInfo getMarketInfo(const string& market) override;
virtual StockTypeInfo getStockTypeInfo(uint32_t type) override;
virtual StockInfo getStockInfo(string market, const string& code) override;
virtual vector<StockInfo> getAllStockInfo() override;
virtual std::unordered_set<Datetime> getAllHolidays() override;

private:
//股票基本信息数据库实例
// 股票基本信息数据库实例
ConnectPool<SQLiteConnect>* m_pool;
};

Expand Down
14 changes: 14 additions & 0 deletions hikyuu_cpp/hikyuu/data_driver/base_info/table/StockWeightTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ class StockWeightTable {
double freeCount{0.};
};

struct StockWeightTableView {
TABLE_BIND9(StockWeightTableView, stkweight, market_code, date, countAsGift, countForSell,
priceForSell, bonus, countOfIncreasement, totalCount, freeCount)
string market_code;
uint64_t date{0};
double countAsGift{0.};
double countForSell{0.};
double priceForSell{0.};
double bonus{0.};
double countOfIncreasement{0.};
double totalCount{0.};
double freeCount{0.};
};

} // namespace hku

#endif /* HIKYUU_DATA_DRIVER_BASE_INFO_TABLE_STOCKWEIGHTTABLE_H */
36 changes: 36 additions & 0 deletions hikyuu_cpp/hikyuu/utilities/db_connect/DBConnectBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ class HKU_API DBConnectBase : public std::enable_shared_from_this<DBConnectBase>
template <typename T>
void load(T &item, const DBCondition &cond);

/**
* 加载模型数据至指定的模型实例, 仅供查询
* @param item 指定的模型实例
* @param sql 查询条件 select 的 sql 语句
*/
template <typename T>
void loadView(T &item, const std::string &sql);

/**
* 批量加载模型数据至容器(vector,list 等支持 push_back 的容器)
* @param container 指定容器
Expand All @@ -155,6 +163,14 @@ class HKU_API DBConnectBase : public std::enable_shared_from_this<DBConnectBase>
template <typename Container>
void batchLoad(Container &container, const DBCondition &cond);

/**
* 批量加载模型数据至容器(vector,list 等支持 push_back 的容器)
* @param container 指定容器
* @param sql select 的查询语句
*/
template <typename Container>
void batchLoadView(Container &container, const std::string &sql);

/**
* 批量更新
* @param container 拥有迭代器的容器
Expand Down Expand Up @@ -456,6 +472,26 @@ void DBConnectBase::batchLoad(Container &container, const DBCondition &cond) {
batchLoad(container, cond.str());
}

template <typename T>
void DBConnectBase::loadView(T &item, const std::string &sql) {
SQLStatementPtr st = getStatement(sql);
st->exec();
if (st->moveNext()) {
item.load(st);
}
}

template <typename Container>
void DBConnectBase::batchLoadView(Container &container, const std::string &sql) {
SQLStatementPtr st = getStatement(sql);
st->exec();
while (st->moveNext()) {
typename Container::value_type tmp;
tmp.load(st);
container.push_back(tmp);
}
}

template <class Container>
inline void DBConnectBase::batchUpdate(Container &container, bool autotrans) {
batchUpdate(container.begin(), container.end(), autotrans);
Expand Down

0 comments on commit 41866b7

Please sign in to comment.