Skip to content

Commit

Permalink
Merge pull request #778 from mkstoyanov/fix_wavelet_changegpu
Browse files Browse the repository at this point in the history
fix bug when wavelet grids change the gpu context
  • Loading branch information
mkstoyanov authored Oct 18, 2024
2 parents 513b72c + a6009ff commit 51f0e00
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion InterfacePython/TasmanianSG.py
Original file line number Diff line number Diff line change
Expand Up @@ -2140,7 +2140,7 @@ def enableAcceleration(self, sAccelerationType, iGPUID = None):
else:
if ((iGPUID < 0) or (iGPUID >= self.getNumGPUs())):
raise TasmanianInputError("iGPUID", "ERROR: invalid GPU ID number")
pLibTSG.tsgEnableAcceleration(self.pGrid, bytes(sAccelerationType, encoding='utf8'), iGPUID)
pLibTSG.tsgEnableAccelerationGPU(self.pGrid, bytes(sAccelerationType, encoding='utf8'), iGPUID)

def getAccelerationType(self):
'''
Expand Down
3 changes: 3 additions & 0 deletions SparseGrids/tsgGridWavelet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ void GridWavelet::buildInterpolationMatrix() const{

if (order == 1 and TasSparse::WaveletBasisMatrix::useDense(acceleration, num_points)
and acceleration->useKernels()){ // using the GPU algorithm
acceleration->setDevice();
std::vector<double> pnts(Utils::size_mult(num_dimensions, num_points));
getPoints(pnts.data());
GpuVector<double> gpu_pnts(acceleration, pnts);
Expand Down Expand Up @@ -1040,6 +1041,8 @@ void GridWavelet::updateAccelerationData(AccelerationContext::ChangeType change)
case AccelerationContext::change_gpu_device:
gpu_cache.reset();
gpu_cachef.reset();
if (inter_matrix.getNumRows() > 0)
inter_matrix = TasSparse::WaveletBasisMatrix();
break;
case AccelerationContext::change_sparse_dense:
if ((acceleration->algorithm_select == AccelerationContext::algorithm_dense and inter_matrix.isSparse())
Expand Down

0 comments on commit 51f0e00

Please sign in to comment.