Skip to content

Commit

Permalink
ROMHandler::SaveOperator.
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamer2368 committed Jan 12, 2024
1 parent 491908a commit 586bfca
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 40 deletions.
2 changes: 2 additions & 0 deletions include/multiblock_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ friend class ParameterizedProblem;
virtual void AssembleInterfaceMatrixes() = 0;

// Global ROM operator Loading.
virtual void SaveROMOperator(const std::string input_prefix="")
{ rom_handler->SaveOperator(input_prefix); }
virtual void LoadROMOperatorFromFile(const std::string input_prefix="")
{ rom_handler->LoadOperatorFromFile(input_prefix); }

Expand Down
4 changes: 3 additions & 1 deletion include/rom_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class ROMHandlerBase
const TrainMode GetTrainMode() { return train_mode; }
const int GetNumROMRefBlocks() { return num_rom_ref_blocks; }
const int GetComponentNumBasis(const int &basis_idx) { return num_ref_basis[basis_idx]; }
const ROMBuildingLevel SaveOperator() { return save_operator; }
const ROMBuildingLevel GetBuildingLevel() { return save_operator; }
const bool BasisLoaded() { return basis_loaded; }
const bool OperatorLoaded() { return operator_loaded; }
const std::string GetOperatorPrefix() { return operator_prefix; }
Expand Down Expand Up @@ -155,6 +155,7 @@ class ROMHandlerBase
virtual void Solve(BlockVector* U) = 0;
virtual void NonlinearSolve(Operator &oper, BlockVector* U, Solver *prec=NULL) = 0;

virtual void SaveOperator(const std::string input_prefix="") = 0;
virtual void LoadOperatorFromFile(const std::string input_prefix="") = 0;
virtual void LoadOperator(BlockMatrix *input_mat) = 0;

Expand Down Expand Up @@ -230,6 +231,7 @@ class MFEMROMHandler : public ROMHandlerBase
virtual void Solve(BlockVector* U);
virtual void NonlinearSolve(Operator &oper, BlockVector* U, Solver *prec=NULL) override;

virtual void SaveOperator(const std::string input_prefix="");
virtual void LoadOperatorFromFile(const std::string input_prefix="");
virtual void LoadOperator(BlockMatrix *input_mat);

Expand Down
1 change: 1 addition & 0 deletions include/steady_ns_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ friend class SteadyNSOperator;

virtual void Assemble();

virtual void SaveROMOperator(const std::string input_prefix="");
virtual void LoadROMOperatorFromFile(const std::string input_prefix="");

// Component-wise assembly
Expand Down
5 changes: 3 additions & 2 deletions src/main_workflow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ void BuildROM(MPI_Comm comm)
rom->LoadReducedBasis();

TopologyHandlerMode topol_mode = test->GetTopologyMode();
ROMBuildingLevel save_operator = rom->SaveOperator();
ROMBuildingLevel save_operator = rom->GetBuildingLevel();

// NOTE(kevin): global operator required only for global rom operator.
if (save_operator == ROMBuildingLevel::GLOBAL)
Expand All @@ -335,6 +335,7 @@ void BuildROM(MPI_Comm comm)
case ROMBuildingLevel::GLOBAL:
{
test->ProjectOperatorOnReducedBasis();
test->SaveROMOperator();
break;
}
case ROMBuildingLevel::NONE:
Expand Down Expand Up @@ -393,7 +394,7 @@ double SingleRun(MPI_Comm comm, const std::string output_file)
if (test->UseRom())
{
printf("ROM with ");
ROMBuildingLevel save_operator = rom->SaveOperator();
ROMBuildingLevel save_operator = rom->GetBuildingLevel();
TopologyHandlerMode topol_mode = test->GetTopologyMode();

if (topol_mode == TopologyHandlerMode::SUBMESH)
Expand Down
31 changes: 20 additions & 11 deletions src/rom_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,20 +385,29 @@ void MFEMROMHandler::ProjectOperatorOnReducedBasis(const Array2D<Operator*> &mat
romMat_mono = romMat->CreateMonolithic();

if (linsol_type == SolverType::DIRECT) SetupDirectSolver();
}

if (save_operator == ROMBuildingLevel::GLOBAL)
{
std::string filename = operator_prefix + ".h5";
hid_t file_id;
herr_t errf = 0;
file_id = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT);
assert(file_id >= 0);
void MFEMROMHandler::SaveOperator(const std::string input_prefix)
{
assert(save_operator == ROMBuildingLevel::GLOBAL);
assert(operator_loaded && romMat);

hdf5_utils::WriteBlockMatrix(file_id, "ROM_matrix", romMat);
std::string filename;
if (input_prefix == "")
filename = operator_prefix;
else
filename = input_prefix;
filename += ".h5";

errf = H5Fclose(file_id);
assert(errf >= 0);
}
hid_t file_id;
herr_t errf = 0;
file_id = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT);
assert(file_id >= 0);

hdf5_utils::WriteBlockMatrix(file_id, "ROM_matrix", romMat);

errf = H5Fclose(file_id);
assert(errf >= 0);
}

void MFEMROMHandler::ProjectVectorOnReducedBasis(const BlockVector* vec, BlockVector*& rom_vec)
Expand Down
54 changes: 28 additions & 26 deletions src/steady_ns_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ SteadyNSSolver::~SteadyNSSolver()
if (use_rom)
{
DeletePointers(comp_tensors);
if (rom_handler->SaveOperator() != ROMBuildingLevel::COMPONENT)
if (rom_handler->GetBuildingLevel() != ROMBuildingLevel::COMPONENT)
DeletePointers(subdomain_tensors);
}
}
Expand Down Expand Up @@ -362,9 +362,35 @@ void SteadyNSSolver::Assemble()
// nonlinear operator?
}

void SteadyNSSolver::SaveROMOperator(const std::string input_prefix)
{
MultiBlockSolver::SaveROMOperator(input_prefix);

std::string filename = rom_handler->GetOperatorPrefix() + ".h5";
assert(FileExists(filename));

hid_t file_id;
herr_t errf = 0;
file_id = H5Fopen(filename.c_str(), H5F_ACC_RDWR, H5P_DEFAULT);
assert(file_id >= 0);

hid_t grp_id;
grp_id = H5Gcreate(file_id, "ROM_tensors", H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
assert(grp_id >= 0);

for (int m = 0; m < numSub; m++)
hdf5_utils::WriteDataset(grp_id, "subdomain" + std::to_string(m), *subdomain_tensors[m]);

errf = H5Gclose(grp_id);
assert(errf >= 0);

errf = H5Fclose(file_id);
assert(errf >= 0);
}

void SteadyNSSolver::LoadROMOperatorFromFile(const std::string input_prefix)
{
assert(rom_handler->SaveOperator() == ROMBuildingLevel::GLOBAL);
assert(rom_handler->GetBuildingLevel() == ROMBuildingLevel::GLOBAL);

rom_handler->LoadOperatorFromFile(input_prefix);

Expand Down Expand Up @@ -581,30 +607,6 @@ void SteadyNSSolver::ProjectOperatorOnReducedBasis()
rom_handler->GetBasisOnSubdomain(m, basis);
subdomain_tensors[m] = GetReducedTensor(basis, ufes[m]);
}

if (rom_handler->SaveOperator() == ROMBuildingLevel::GLOBAL)
{
std::string filename = rom_handler->GetOperatorPrefix() + ".h5";
assert(FileExists(filename));

hid_t file_id;
herr_t errf = 0;
file_id = H5Fopen(filename.c_str(), H5F_ACC_RDWR, H5P_DEFAULT);
assert(file_id >= 0);

hid_t grp_id;
grp_id = H5Gcreate(file_id, "ROM_tensors", H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
assert(grp_id >= 0);

for (int m = 0; m < numSub; m++)
hdf5_utils::WriteDataset(grp_id, "subdomain" + std::to_string(m), *subdomain_tensors[m]);

errf = H5Gclose(grp_id);
assert(errf >= 0);

errf = H5Fclose(file_id);
assert(errf >= 0);
}
}

void SteadyNSSolver::SolveROM()
Expand Down

0 comments on commit 586bfca

Please sign in to comment.