-
Notifications
You must be signed in to change notification settings - Fork 136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ASA core modules #641
base: develop
Are you sure you want to change the base?
ASA core modules #641
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Partway through src
src/sunadjointcheckpointscheme/fixed/sunadjointcheckpointscheme_fixed.c
Outdated
Show resolved
Hide resolved
src/sunadjointcheckpointscheme/fixed/sunadjointcheckpointscheme_fixed.c
Outdated
Show resolved
Hide resolved
src/sunadjointcheckpointscheme/fixed/sunadjointcheckpointscheme_fixed.c
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Continuing through src
|
||
if (!(step_num % IMPL_MEMBER(self, interval))) | ||
{ | ||
if (stage_num == 0) { *yes_or_no = SUNTRUE; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will be z[0]
rather than y_n
for methods with an implicit first stage.
{ | ||
SUNFunctionBegin(self->sunctx); | ||
|
||
void* queue = NULL; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here and below, I think a NULL
queue will cause an error with SYCL
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm. Im not sure how to get access to a queue here. I suppose we could check if the vector is a SYCL vector, and if it is, get it from the vector?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose another option is to get it from the memory helper, but it would simiarly require checking the type of SUNMemoryHelper
and adding a get function for the queue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see a good way around this at the moment. The memory helper does not store the queue (it expects the queue to be provided as an input to the memory helper functions). We could add a "get execution queue" function to the NVector but not all the functions that call a memory helper function have access to a vector. Looking back we maybe should have added the queue (stream) as an input the memory helper constructor instead of having it in the function signatures.
SUNMemoryType buffer_mem_type = N_VGetDeviceArrayPointer(v) ? SUNMEMTYPE_DEVICE | ||
: SUNMEMTYPE_HOST; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here and below, user supplied vectors my not supply N_VGetDeviceArrayPointer
but still use device data. I think we'll need a new NVector function to query what memory space the data is stored in.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree the ideal long term solution is to add a function to query the memory space. However, an alternative is that we just require N_VGetDeviceArrayPointer
to be provided when using this module. We already do require it when using a GPU-enabled direct SUNLinearSolver
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GPU-enabled direct SUNLinearSolvers
require this because they need the underlying data pointer. In this case that's not required and I could see this a being potentially problematic because it might give the impression that user need to provide a working GetDeviceArrayPointer
implementation or we ask them to make a dummy version that does not work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Finished pass over src
, starting on test
return SUN_SUCCESS; | ||
} | ||
|
||
SUNErrCode SUNAdjointStepper_ReInit(SUNAdjointStepper self, N_Vector y0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[discussion] The current behavior is like ReInit
, but we do not have a SUNStepperReInit
function. So the user would need to reinitialize the forward and adjoint steppers. We could add SUNStepperReInit
with the same signature of SUNStepperReset
. This would require some updates to the ARKODE ReInit
functions that would be wrapped by a generic ReInit
to allow for NULL
inputs in order to retain the current RHS functions (and some other changes handle NULL
inputs).
|
||
while ((direction == -one && t > tout) || (direction == one && t < tout)) | ||
{ | ||
SUNCheckCall(SUNAdjointStepper_OneStep(self, tout, sens, tret)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To get interpolated output at the desired time this should use SUNStepper_Evolve
(or we'll need to add a SUNStepper
function to get interpolated output)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, indeed this should have been calling SUNStepper_Evolve
.
self->last_flag = adj_sunstepper->last_flag; | ||
|
||
self->step_idx--; | ||
self->nst++; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to store nst
or can we get this from the adj_sunstepper
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The SUNStepper
does not have a nst
counter nor a method to get it from the underlying integrator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should add a Get
function to the SUNStepper
since this counter won't get incremented if using SUNAdjointStepper_Evolve
.
Co-authored-by: David Gardner <gardner48@llnl.gov>
… into feature/sunadjoint-core-modules
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I accidentally started a 'review' when responding to comments.
|
||
Function pointer to insert a checkpoint state represented as a :c:type:`N_Vector`. | ||
|
||
.. c:member:: SUNErrCode (*loadVector)(SUNAdjointCheckpointScheme cs, int64_t step_num, int64_t stage_num, sunbooleantype peek, N_Vector* yout, sunrealtype* tout) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we could check that within the checkpoint scheme since we do not know the time step size. We could probably check that in the package-specific code though.
This PR adds the core modules that will support adjoint sensitivity analysis in a package-agnostic way.