Skip to content

Commit

Permalink
More efficient use of memory.
Browse files Browse the repository at this point in the history
  • Loading branch information
auralius committed Apr 12, 2022
1 parent 2d03d99 commit 6f151ba
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 147 deletions.
28 changes: 15 additions & 13 deletions src/ex09_dpa_piecewise_mass_spring.m
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
clear

% Setup the states and the inputs
X = 0 : 0.025 : 5;
Y = 0 : 0.025 : 5;
Ux = 0 : 0.05 : 2;
Uy = -3 : 0.05 : 3;
X = 0 : 0.02 : 5;
Y = 1 : 0.02 : 5;
U = 0 : 0.02 : 2;
V = -2 : 0.02 : 2;

% Initiate the solver
dpf.states = {X, Y};
dpf.inputs = {Ux Uy};
dpf.inputs = {U, V};
dpf.T_ocp = 1;
dpf.T_dyn = 1;
dpf.T_dyn = 0.1;
dpf.n_horizon = 7;
dpf.state_update_fn = @state_update_fn;
dpf.stage_cost_fn = @stage_cost_fn;
Expand All @@ -31,18 +31,20 @@
% Initiate and run the solver, do forwar tracing and plot the results
dpf = yadpf_solve(dpf);
dpf = yadpf_trace(dpf, [0 5]);
yadpf_plot(dpf, '-o');
yadpf_plot(dpf, '-');

% Additional plotting
figure
plot(dpf.x_star{1}, dpf.x_star{2}, '-o', 'LineWidth', 2);
hold on;
plot(dpf.x_star{1}, dpf.x_star{2}, '-', 'LineWidth', 2);
plot(dpf.x_star_unsimulated{1}, dpf.x_star_unsimulated{2}, '*r', 'LineWidth', 2);
xlabel('x')
ylabel('y')

%% The state update function
function X = state_update_fn(X, U, ~)
X{1} = X{1} + U{1};
X{2} = X{2} + U{2};
function X = state_update_fn(X, U, dt)
X{1} = X{1} + U{1} * dt;
X{2} = X{2} + U{2} * dt;
end

%% The stage cost function
Expand All @@ -62,8 +64,8 @@
%% The terminal cost function
function J = terminal_cost_fn(X)
% Control gains for the terminal node
k1 = 1000;
k2 = 1000;
k1 = 300;
k2 = 300;

% Targetted terminal states
xf = [5 4];
Expand Down
70 changes: 37 additions & 33 deletions src/ex13_dpa_two_input_mass.m
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,56 @@
clc

% Setup the states and the inputs
X1 = 0 : 0.001 : 1;
X2 = 0 : 0.01 : 1;
U1 = 0 : 0.1 : 5;
U2 = -4 : 0.1 : 0;
X = 0 : 0.001 : 1;
V = 0 : 0.001 : 1;
F1 = 0 : 0.1 : 3;
F2 = -4 : 0.1 : 0;

% Setup the horizon
Tf = 1; % 1 second
dt = 0.2; % Temporal discretization step
t = 0:dt:Tf;
n_horizon = length(t);
Tf = 1; % 1 second
T_ocp = 0.1; % Temporal discretization step
t = 0 : T_ocp : Tf;

% Initiate the solver
dps = dps_2X_2U(X1, X2, U1, U2, n_horizon, @state_update_fn, @stage_cost_fn, ...
@terminal_cost_fn, dt, 0.01);
dpf.states = {X, V};
dpf.inputs = {F1, F2};
dpf.T_ocp = T_ocp;
dpf.T_dyn = 0.01; % Time step for the dynamic simulation
dpf.n_horizon = length(t);
dpf.state_update_fn = @state_update_fn;
dpf.stage_cost_fn = @stage_cost_fn;
dpf.terminal_cost_fn = @terminal_cost_fn;

% Extract meaningful results
dps = forward_trace(dps, [0 0]);
% Initiate and run the solver, do forward tracing for the given initial
% condition and plot the results
dpf = yadpf_solve(dpf);
dpf = yadpf_trace(dpf, [0 0]); % Initial state: [0 0]
yadpf_plot(dpf, '-');

% Do plotting here
plot_results(dps, '-');
% Optional: draw the reachability plot
yadpf_rplot(dpf, [0.5 0], 0.1);

%%
function [x1_next, x2_next] = state_update_fn(x1, x2, u1, u2, dt)
m = 1;

x1_next = x1 + dt*x2;
x2_next = x2 + dt/m.*(u1+u2);
function X = state_update_fn(X, F, dt)
m = 1; % Mass
b = 0.1; % Damping coefficient

X{1} = X{1} + dt*X{2};
X{2} = X{2} - b/m*dt.*X{2} + dt/m.*(F{1}+F{2});

end

%%
function J = stage_cost_fn(x1, x2, u1, u2, k, dt)
a1 = 1;
a2 = 1;
J = a1*dt*u1.^2 + a2*dt*u2.^2;
function J = stage_cost_fn(X, F, k, dt)
J = dt*F{1}.^2;
end

%%
function J = terminal_cost_fn(x1, x2)
% Weighting factors
a2 = 1000;
a3 = 1000;
%% The terminal cost function
function J = terminal_cost_fn(X)
xf = [0.5 0];

% Final states
xf = 0.5;
vf = 0;
% Control gains
r1 = 1000;
r2 = 100;

J = a2.*(x1-xf).^2 + a3.*(x2-vf).^2;
J = r1*(X{1}-xf(1)).^2 + r2*(X{2}-xf(2)).^2;
end
10 changes: 10 additions & 0 deletions src/fast_ind2sub2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
function [r, c] = fast_ind2sub2(sz, idx)
% Similar to ind2sub with length(sz) = 2

%------------- BEGIN CODE --------------

r = rem(idx - 1, sz(1)) + 1;
c = (idx - r) / sz(1) + 1;

end
%------------- END OF CODE --------------
2 changes: 1 addition & 1 deletion src/fastsca2mat.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

%------------- BEGIN CODE --------------

V = s*ones(nr, nc);
V = s * ones(nr, nc);

end
%------------- END OF CODE --------------
2 changes: 1 addition & 1 deletion src/fastsub2ind2.m
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

%------------- BEGIN CODE --------------

ind = rows + (cols-1)*sizes(1);
ind = uint32(rows + (cols - 1) * sizes(1));

end
%------------- END OF CODE --------------
11 changes: 11 additions & 0 deletions src/flexible_ind2sub.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function b = flexible_ind2sub(sz, ind)
% The output for [s1, s2, ...] = ind2sub(sz, ind) is comma separated.
% In some cases, we do not know yet the length.

%------------- BEGIN CODE --------------

b = cell(1, length(sz));
[b{:}] = ind2sub(sz, ind);

end
%------------- END OF CODE --------------
13 changes: 13 additions & 0 deletions src/flexible_sub2ind.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function ind = flexible_sub2ind(sz, varargin)
% sub2ind(sz, sub) does not work if length(sz) == 1

%------------- BEGIN CODE --------------

if length(sz) > 1
ind = uint32(sub2ind(sz, varargin{:}));
else
ind = uint32(varargin{:});
end

end
%------------- END OF CODE --------------
38 changes: 25 additions & 13 deletions src/reachability_plot_1X.m
Original file line number Diff line number Diff line change
Expand Up @@ -18,51 +18,63 @@ function reachability_plot_1X(dpf, terminal_state, terminal_tol)
figure;
hold on;

x_star = zeros(1, dpf.n_horizon);
buffer = zeros(dpf.nX, dpf.n_horizon);
% Structure is slow, so we will unload the strucutre
n_horizon = dpf.n_horizon;
nXX = dpf.nXX;
nX = dpf.nX;
states = dpf.states;
descendant_matrix = dpf.descendant_matrix;

clear dpf;

x_star = zeros(1, n_horizon);
buffer = zeros(nX, n_horizon);

% Test for all nodes at stage-1 (every possibe ICs)
fprintf('Generating the reachability plot...\n')
fprintf('Progress: ')
ll = 0;

for j = 1:dpf.nXX
fprintf(repmat('\b',1,ll));
ll = fprintf('%.1f %%',j/dpf.nXX*100);
step = max(floor(nXX/10), 1);
for j = 1 : nXX
if rem(j-1, step) == 0
fprintf(repmat('\b',1,ll));
ll = fprintf('%.1f %%',(j-1) / nXX*100);
end

id = j;

for k = 1 : dpf.n_horizon-1
x_star(k) = dpf.states{1}(id);
id = dpf.descendant_matrix(k,id);
for k = 1 : n_horizon-1
x_star(k) = states{1}(id);
id = descendant_matrix(k,id);
end

% The last stage
x_star(dpf.n_horizon) = dpf.states{1}(id);
x_star(n_horizon) = states{1}(id);

% Check the terminal stage, does it end at the desired terminal node?
if abs(dpf.states{1}(id) - terminal_state) < terminal_tol
if abs(states{1}(id) - terminal_state) < terminal_tol
buffer(j,:) = x_star; % If yes, keep them
end
end

fprintf('\nComplete!\n')

% Plot only the maximums and the minimums, color the area in between.
buffer(~any(buffer,2),:) = []; % Delete rows that are all zeros
buffer(~any(buffer, 2), :) = []; % Delete rows that are all zeros

if isempty(buffer)
error('No reachable states are foud, increase the tollerance...\n');
end

mins = min(buffer);
maxs = max(buffer);
k = 1 : dpf.n_horizon;
k = 1 : n_horizon;
plot(k, mins);
plot(k, maxs);
patch([k fliplr(k)], [mins fliplr(maxs)], 'g')

xlim([1 dpf.n_horizon+1])
xlim([1 n_horizon+1])
xlabel(['Stage-' '$k$'], 'Interpreter','latex')
ylabel('$x_1(k)$', 'Interpreter','latex')
title('Backward Reachability Plot');
Expand Down
64 changes: 39 additions & 25 deletions src/reachability_plot_2X.m
Original file line number Diff line number Diff line change
Expand Up @@ -17,54 +17,68 @@ function reachability_plot_2X(dpf, terminal_state, terminal_tol)
figure;
hold on;

x1 = zeros(1, dpf.n_horizon);
x2 = zeros(1, dpf.n_horizon);
% Structure is slow, so we will unload the strucutre
n_horizon = dpf.n_horizon;
nXX = dpf.nXX;
nX = dpf.nX;
states = dpf.states;
descendant_matrix = dpf.descendant_matrix;
x_star_unsimulated = dpf.x_star_unsimulated;

x1s = zeros(dpf.nXX, dpf.n_horizon);
x2s = zeros(dpf.nXX, dpf.n_horizon);
clear dpf;

x1 = zeros(1, n_horizon);
x2 = zeros(1, n_horizon);

x1s = zeros(nXX, n_horizon);
x2s = zeros(nXX, n_horizon);

% Test for all nodes at stage-1 (every possibe ICs)
fprintf('Generating the reachability plot...\n')
fprintf('Progress: ')
ll = 0;

for j = 1:dpf.nXX
fprintf(repmat('\b',1,ll));
ll = fprintf('%.1f %%',j/dpf.nXX*100);
step = max(floor(nXX/10), 1);
for j = 1 : nXX
if (rem(j-1, step) == 0)
fprintf(repmat('\b', 1, ll));
ll = fprintf('%.1f %%',(j-1) / nXX * 100);
end

id = j;

for k = 1 : dpf.n_horizon-1
[r,c] = ind2sub([dpf.nX(1) dpf.nX(2)], id);
x1(k) = dpf.states{1}(r);
x2(k) = dpf.states{2}(c);
id = dpf.descendant_matrix(k,id);
for k = 1 : n_horizon-1
[r, c] = fast_ind2sub2([nX(1) nX(2)], id);
x1(k) = states{1}(r);
x2(k) = states{2}(c);
id = descendant_matrix(k, id);
end

% The last stage
[r,c] = ind2sub([dpf.nX(1) dpf.nX(2)], id);
x1(dpf.n_horizon) = dpf.states{1}(r);
x2(dpf.n_horizon) = dpf.states{2}(c);
[r, c] = fast_ind2sub2([nX(1) nX(2)], id);
x1(n_horizon) = states{1}(r);
x2(n_horizon) = states{2}(c);

% Check the terminal stage, does it end at the desired terminal node?
x3s = [dpf.states{1}(r) dpf.states{2}(c)];
x3s = [states{1}(r) states{2}(c)];
if norm(x3s - terminal_state) < terminal_tol
x1s(j,:) = x1; % If yes, keep them
x2s(j,:) = x2; % If yes, keep them
end
end
clear x1 x2;

clear x1 x2 descendant_matrix;

fprintf('\nComplete!\n')

% Delete rows that all zeros
a = any(x1s+x2s, 2);
a = any(x1s + x2s, 2);
x1s(~a,:) = [];
x2s(~a,:) = [];

% Convert to 3D pointclouds
k = repmat(1:dpf.n_horizon, size(x1s,1),1);
X = [reshape(x1s,[],1) reshape(x2s,[],1) reshape(k,[],1)];
k = repmat(1 : n_horizon, size(x1s,1),1);
X = [reshape(x1s, [], 1) reshape(x2s, [], 1) reshape(k, [], 1)];
X = unique(X,'rows');
clear a x1s x2s;

Expand All @@ -75,12 +89,12 @@ function reachability_plot_2X(dpf, terminal_state, terminal_tol)
% Scatter plot the pointclouds
fscatter3(X(:,1), X(:,2), X(:,3), X(:,3), jet);
hold on
plot3(dpf.x_star_unsimulated{1}, dpf.x_star_unsimulated{2}, ...
1:dpf.n_horizon, 'k', 'LineWidth', 3);
plot3(x_star_unsimulated{1}, x_star_unsimulated{2}, ...
1:n_horizon, 'k', 'LineWidth', 3);

% Make it beautiful
xlim([min(dpf.states{1}) max(dpf.states{1})]);
ylim([min(dpf.states{2}) max(dpf.states{2})]);
xlim([min(states{1}) max(states{1})]);
ylim([min(states{2}) max(states{2})]);

zlabel(['Stage-' '$k$'], 'Interpreter','latex')
xlabel('$x_1(k)$', 'Interpreter','latex')
Expand All @@ -91,6 +105,6 @@ function reachability_plot_2X(dpf, terminal_state, terminal_tol)

ax = gca;
ax.SortMethod = 'childorder';
ax.FontName = 'times';
ax.FontName = 'Times';
end
%------------- END OF CODE --------------
Loading

0 comments on commit 6f151ba

Please sign in to comment.