ASA core modules #641

Open
wants to merge 450 commits into
base: develop

Conversation

balos1
Member

@balos1 balos1 commented Jan 14, 2025

This PR adds the core modules that will support adjoint sensitivity analysis in a package-agnostic way.

Member

@gardner48 gardner48 left a comment

Partway through src

CHANGELOG.md (outdated, resolved)
doc/shared/figs/sunadjoint_ckpt_fixed.png (outdated, resolved)
doc/shared/sunadjoint/SUNAdjointCheckpointScheme.rst (outdated, resolved)
Member

@gardner48 gardner48 left a comment

Continuing through src

test/unit_tests/sundials/CMakeLists.txt (outdated, resolved)
test/unit_tests/sundials/CMakeLists.txt (outdated, resolved)

if (!(step_num % IMPL_MEMBER(self, interval)))
{
  if (stage_num == 0) { *yes_or_no = SUNTRUE; }
Member

This will be z[0] rather than y_n for methods with an implicit first stage.
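
For reference, the quoted check can be read as the standalone predicate below. This is a minimal sketch based only on the two lines shown in the diff: `toy_needs_checkpoint` and its parameters are hypothetical names, not the module's API, the guard against a non-positive interval is an added assumption, and the real function may handle more cases.

```c
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical stand-in for the fixed-interval check quoted above.
 * Returns true when the state at (step_num, stage_num) should be saved. */
bool toy_needs_checkpoint(int64_t step_num, int64_t stage_num, int64_t interval)
{
  /* Assumption: treat a non-positive interval as "never checkpoint"
   * so the modulo below is always well defined. */
  if (interval <= 0) { return false; }

  /* Only steps that are a multiple of the interval are candidates. */
  if (step_num % interval != 0) { return false; }

  /* Only the first stage of such a step is saved. */
  return stage_num == 0;
}
```

With an interval of 5, this selects stage 0 of steps 0, 5, 10, and so on. The review comment is about what that stage-0 data actually holds: for a method with an implicit first stage it is z[0], not y_n.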

{
  SUNFunctionBegin(self->sunctx);

  void* queue = NULL;
Member

Here and below, I think a NULL queue will cause an error with SYCL

Member Author

Hmm. I'm not sure how to get access to a queue here. I suppose we could check whether the vector is a SYCL vector and, if it is, get the queue from the vector?

Member Author

I suppose another option is to get it from the memory helper, but that would similarly require checking the type of the SUNMemoryHelper and adding a get function for the queue.

Member

I don't see a good way around this at the moment. The memory helper does not store the queue (it expects the queue to be provided as an input to the memory helper functions). We could add a "get execution queue" function to the NVector, but not all the functions that call a memory helper function have access to a vector. Looking back, we maybe should have added the queue (stream) as an input to the memory helper constructor instead of having it in the function signatures.
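
To make the two designs in this thread concrete, here is a toy sketch in plain C. Everything in it is hypothetical (`ToyMemHelper`, `toy_helper_create`, `toy_copy_with_queue`, `toy_copy`); it does not use the SUNMemoryHelper API, and the "NULL queue is an error" rule only imitates the SYCL concern raised above.

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical helper that stores its execution queue at construction. */
typedef struct
{
  void* queue; /* opaque queue/stream handle, fixed when the helper is made */
} ToyMemHelper;

ToyMemHelper toy_helper_create(void* queue)
{
  ToyMemHelper helper;
  helper.queue = queue;
  return helper;
}

/* Per-call style: every caller has to supply the queue itself.
 * Returns 0 on success, -1 if no queue was provided. */
int toy_copy_with_queue(void* dst, const void* src, size_t nbytes, void* queue)
{
  if (queue == NULL) { return -1; } /* stand-in for a backend that needs a queue */
  memcpy(dst, src, nbytes);
  return 0;
}

/* Constructor style: the helper already knows its queue, so callers
 * that have no vector (and no queue) at hand can still copy. */
int toy_copy(const ToyMemHelper* helper, void* dst, const void* src, size_t nbytes)
{
  return toy_copy_with_queue(dst, src, nbytes, helper->queue);
}
```

In the per-call style, a caller with `queue = NULL` fails in `toy_copy_with_queue`; in the constructor style the queue travels with the helper, which is the trade-off described in the comment above.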

Comment on lines +309 to +310
SUNMemoryType buffer_mem_type = N_VGetDeviceArrayPointer(v) ? SUNMEMTYPE_DEVICE
                                                            : SUNMEMTYPE_HOST;
Member

Here and below, user-supplied vectors may not supply N_VGetDeviceArrayPointer but still use device data. I think we'll need a new NVector function to query what memory space the data is stored in.

Member Author

I agree the ideal long term solution is to add a function to query the memory space. However, an alternative is that we just require N_VGetDeviceArrayPointer to be provided when using this module. We already do require it when using a GPU-enabled direct SUNLinearSolver.

Member

GPU-enabled direct SUNLinearSolvers require this because they need the underlying data pointer. In this case that's not required, and I could see this being potentially problematic because it might give the impression that users need to provide a working GetDeviceArrayPointer implementation, or that we are asking them to make a dummy version that does not work.
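
One way to picture the "query the memory space" idea from this thread is the toy sketch below. It is not the N_Vector API: `ToyVec`, `ToyVecOps`, `ToyMemSpace`, and `toy_vec_memory_space` are hypothetical, and the fallback order (explicit query first, then the device-pointer heuristic from the diff) is an assumption for illustration.

```c
#include <stddef.h>

typedef enum
{
  TOY_MEM_HOST,
  TOY_MEM_DEVICE
} ToyMemSpace;

typedef struct ToyVec ToyVec;

/* Hypothetical operations table; either entry may be NULL. */
typedef struct
{
  void* (*get_device_array_pointer)(ToyVec* v);
  ToyMemSpace (*get_memory_space)(ToyVec* v); /* the proposed new query */
} ToyVecOps;

struct ToyVec
{
  const ToyVecOps* ops;
  void* content;
};

/* Prefer the explicit query; fall back to the heuristic in the diff
 * (a non-NULL device pointer means device memory); default to host. */
ToyMemSpace toy_vec_memory_space(ToyVec* v)
{
  if (v->ops->get_memory_space) { return v->ops->get_memory_space(v); }
  if (v->ops->get_device_array_pointer && v->ops->get_device_array_pointer(v))
  {
    return TOY_MEM_DEVICE;
  }
  return TOY_MEM_HOST;
}

/* Example user vector: device data, but no device-pointer accessor. */
static ToyMemSpace toy_user_space(ToyVec* v)
{
  (void)v;
  return TOY_MEM_DEVICE;
}

const ToyVecOps toy_user_ops   = {NULL, toy_user_space};
const ToyVecOps toy_legacy_ops = {NULL, NULL};
```

A vector built on `toy_user_ops` reports device memory without implementing a pointer accessor, while one built on `toy_legacy_ops` is treated as host memory, which is exactly the misclassification the reviewer is worried about when device data sits behind a vector that offers neither function.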

src/sundials/sundatanode/sundatanode_inmem.c (resolved)
src/sundials/sundatanode/sundatanode_inmem.c (resolved)
src/sundials/sundatanode/sundatanode_inmem.c (resolved)
Member

@gardner48 gardner48 left a comment

Finished pass over src, starting on test

src/sundials/sundials_adjointstepper.c (outdated, resolved)
  return SUN_SUCCESS;
}

SUNErrCode SUNAdjointStepper_ReInit(SUNAdjointStepper self, N_Vector y0,
Member

[discussion] The current behavior is like ReInit, but we do not have a SUNStepperReInit function, so the user would need to reinitialize the forward and adjoint steppers. We could add SUNStepperReInit with the same signature as SUNStepperReset. This would require some updates to the ARKODE ReInit functions that would be wrapped by a generic ReInit, to allow for NULL inputs in order to retain the current RHS functions (and some other changes to handle NULL inputs).


while ((direction == -one && t > tout) || (direction == one && t < tout))
{
  SUNCheckCall(SUNAdjointStepper_OneStep(self, tout, sens, tret));
Member

To get interpolated output at the desired time this should use SUNStepper_Evolve (or we'll need to add a SUNStepper function to get interpolated output)

Member Author

Good catch, indeed this should have been calling SUNStepper_Evolve.
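
The difference between looping on one-step calls and evolving to `tout` can be seen in a toy model. The sketch below is not SUNDIALS code: `ToyState`, `toy_step_past`, and `toy_evolve` are hypothetical, the "integrator" is a fixed-step explicit Euler method for y' = 2, and linear interpolation stands in for whatever dense output a real stepper provides.

```c
typedef struct
{
  double t;
  double y;
} ToyState;

/* One fixed-size explicit Euler step for y' = 2; direction is +1.0 or -1.0. */
static void toy_one_step(ToyState* s, double direction, double h)
{
  s->t += direction * h;
  s->y += direction * h * 2.0;
}

/* Same loop shape as the quoted code: stop once tout has been reached
 * or passed, and return whatever state the last full step landed on. */
ToyState toy_step_past(ToyState s, double tout, double direction, double h)
{
  while ((direction == -1.0 && s.t > tout) || (direction == 1.0 && s.t < tout))
  {
    toy_one_step(&s, direction, h);
  }
  return s;
}

/* Evolve-style variant: take the same steps, then interpolate between the
 * last two states so the returned state is at tout exactly. */
ToyState toy_evolve(ToyState s, double tout, double direction, double h)
{
  ToyState prev = s;
  while ((direction == -1.0 && s.t > tout) || (direction == 1.0 && s.t < tout))
  {
    prev = s;
    toy_one_step(&s, direction, h);
  }
  if (s.t != prev.t && s.t != tout)
  {
    double theta = (tout - prev.t) / (s.t - prev.t);
    s.y = prev.y + theta * (s.y - prev.y);
    s.t = tout;
  }
  return s;
}
```

Starting from t = 1.0, y = 2.0 and integrating backward with h = 0.5 toward tout = 0.25, `toy_step_past` stops at t = 0.0 (past the requested time), while `toy_evolve` returns the state at t = 0.25.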

src/sundials/sundials_adjointstepper.c (resolved)
  self->last_flag = adj_sunstepper->last_flag;

  self->step_idx--;
  self->nst++;
Member

Do we need to store nst or can we get this from the adj_sunstepper?

Member Author

The SUNStepper does not have an nst counter, nor a method to get it from the underlying integrator.

Member

Maybe we should add a Get function to the SUNStepper since this counter won't get incremented if using SUNAdjointStepper_Evolve.

src/sundials/sundials_hashmap.c (resolved)
src/sundials/sundials_hashmap.c (outdated, resolved)
src/sundials/sundials_profiler.c (resolved)
src/sundials/sundials_profiler.c (outdated, resolved)
src/sundials/sundials_hashmap.c (resolved)
Member Author

@balos1 balos1 left a comment

I accidentally started a 'review' when responding to comments.


Function pointer to insert a checkpoint state represented as a :c:type:`N_Vector`.

.. c:member:: SUNErrCode (*loadVector)(SUNAdjointCheckpointScheme cs, int64_t step_num, int64_t stage_num, sunbooleantype peek, N_Vector* yout, sunrealtype* tout)
Member Author

I don't think we could check that within the checkpoint scheme since we do not know the time step size. We could probably check that in the package-specific code though.
