Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Major speedup to simulations #237

Merged
merged 2 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 48 additions & 38 deletions Circuit/Simulation/Simulation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -444,32 +444,24 @@ private static void Solve(CodeGen code, LinqExpr Ab, IEnumerable<LinearCombinati
LinqExpr.Constant(0.0)));
}

// Gaussian elimination on this turd.
// Fully solve this system of equations.
code.Add(LinqExpr.Call(
GetMethod<Simulation>(Vector.IsHardwareAccelerated ? nameof(RowReduceVector) : nameof(RowReduce), Ab.Type, typeof(int), typeof(int)),
GetMethod<Simulation>(Vector.IsHardwareAccelerated ? nameof(SolveVector) : nameof(Solve), Ab.Type, typeof(int), typeof(int)),
Ab,
LinqExpr.Constant(M),
LinqExpr.Constant(N)));
LinqExpr.Constant(N + 1)));

// Ab is now upper triangular, solve it.
for (int j = N - 1; j >= 0; --j)
{
LinqExpr _j = LinqExpr.Constant(j);
LinqExpr Abj = code.ReDeclInit<double[]>("Abj", LinqExpr.ArrayAccess(Ab, _j));

LinqExpr r = LinqExpr.ArrayAccess(Abj, LinqExpr.Constant(N));
for (int ji = j + 1; ji < N; ++ji)
r = LinqExpr.Add(r, LinqExpr.Multiply(LinqExpr.ArrayAccess(Abj, LinqExpr.Constant(ji)), code[deltas[ji]]));
code.DeclInit(deltas[j], LinqExpr.Divide(LinqExpr.Negate(r), LinqExpr.ArrayAccess(Abj, _j)));
}
// Extract the solutions.
for (int j = 0; j < N; ++j)
code.DeclInit(deltas[j], LinqExpr.Negate(LinqExpr.ArrayAccess(LinqExpr.ArrayAccess(Ab, LinqExpr.Constant(j)), LinqExpr.Constant(N))));
}

// A human readable implementation of RowReduce.
public static void RowReduce(double[][] Ab, int M, int N)
public static void Solve(double[][] Ab, int M, int N)
{
// Solve for dx.
// For each variable in the system...
for (int j = 0; j + 1 < N; ++j)
// For each column...
for (int j = 0; j < Math.Min(M, N); ++j)
{
int pi = j;
double max = Math.Abs(Ab[j][j]);
Expand Down Expand Up @@ -497,25 +489,35 @@ public static void RowReduce(double[][] Ab, int M, int N)

double[] Abj = Ab[j];

// Eliminate the rows after the pivot.
// Eliminate all other rows.
double p = Abj[j];
for (int i = j + 1; i < M; ++i)
if (p == 0) continue;
for (int i = 0; i < M; ++i)
{
if (i == j) continue;
double[] Abi = Ab[i];
if (Abi[j] == 0.0) continue;

double s = Abi[j] / p;
if (s != 0.0)
for (int ij = j + 1; ij <= N; ++ij)
Abi[ij] -= Abj[ij] * s;
for (int ij = j + 1; ij < N; ++ij)
Abi[ij] -= Abj[ij] * s;
}

// Scale the pivot row, so the pivot is one.
double inv_p = 1.0 / p;
for (int ij = j + 1; ij < N; ++ij)
Abj[ij] *= inv_p;
}
}

//This algorith has no tail-loop - it requires arrays to be padded to N + Vector.Count - 1
private static void RowReduceVector(double[][] Ab, int M, int N)
private static void SolveVector(double[][] Ab, int M, int N)
{
var vectorLength = Vector<double>.Count;

// Solve for dx.
// For each variable in the system...
for (int j = 0; j + 1 < N; ++j)
for (int j = 0; j < Math.Min(M, N); ++j)
{
int pi = j;
double max = Math.Abs(Ab[j][j]);
Expand All @@ -540,24 +542,32 @@ private static void RowReduceVector(double[][] Ab, int M, int N)
Ab[j] = tmp;
}

var vectorLength = Vector<double>.Count;
// Eliminate the rows after the pivot.
double p = Ab[j][j];
for (int i = j + 1; i < M; ++i)
double[] Abj = Ab[j];

// Eliminate all other rows.
double p = Abj[j];
if (p == 0) continue;
for (int i = 0; i < M; ++i)
{
double s = Ab[i][j] / p;
if (s != 0.0)
if (i == j) continue;
double[] Abi = Ab[i];
if (Abi[j] == 0) continue;

double s = Abi[j] / p;
for (int ij = j + 1; ij < N; ij += vectorLength)
{
int jj;
for (jj = j + 1; jj <= N; jj += vectorLength)
{
var source = new Vector<double>(Ab[j], jj);
var target = new Vector<double>(Ab[i], jj);
var res = target - (source * s);
res.CopyTo(Ab[i], jj);
}
var source = new Vector<double>(Abj, ij);
var target = new Vector<double>(Abi, ij);
var res = target - (source * s);
res.CopyTo(Abi, ij);
}
}

// Scale the pivot row, so the pivot is one.
// TODO: Vectorize
double inv_p = 1.0 / p;
for (int ij = j + 1; ij < N; ++ij)
Abj[ij] *= inv_p;
}
}

Expand Down
3 changes: 3 additions & 0 deletions Circuit/Simulation/TransientSolution.cs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ public static TransientSolution Solve(Analysis Analysis, Expression TimeStep, IE
solutions.Count,
solutions.Sum(i => i.Unknowns.Count()));

// Solutions from `Solve` might depend on previous solutions, so we need to make sure to emit the solutions in the order that satisifies such dependencies.
solutions.Reverse();

return new TransientSolution(
h,
solutions,
Expand Down
2 changes: 1 addition & 1 deletion ComputerAlgebra