Skip to content

Commit

Permalink
Nn descent (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak authored Mar 5, 2024
1 parent 4dccc0a commit ce39654
Show file tree
Hide file tree
Showing 6 changed files with 1,317 additions and 11 deletions.
13 changes: 10 additions & 3 deletions lib/scholar/neighbors/large_vis.ex
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,23 @@ defmodule Scholar.Neighbors.LargeVis do
num_trees = opts[:num_trees] || 5 + round(:math.pow(size, 0.25))
key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end)

fit_n(tensor, num_neighbors: k, min_leaf_size: min_leaf_size, num_trees: num_trees, key: key)
fit_n(
tensor,
key,
num_neighbors: k,
min_leaf_size: min_leaf_size,
num_trees: num_trees,
num_iters: opts[:num_iters]
)
end

defnp fit_n(tensor, opts) do
defnp fit_n(tensor, key, opts) do
forest =
Forest.fit(tensor,
num_neighbors: opts[:num_neighbors],
min_leaf_size: opts[:min_leaf_size],
num_trees: opts[:num_trees],
key: opts[:key]
key: key
)

{graph, _} = Forest.predict(forest, tensor)
Expand Down
Loading

0 comments on commit ce39654

Please sign in to comment.