Skip to content

Commit

Permalink
Formatted description, added choice of KNN algorithm, Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
norm4nn committed Aug 15, 2024
1 parent 43f8831 commit 5c7812e
Showing 1 changed file with 74 additions and 24 deletions.
98 changes: 74 additions & 24 deletions lib/scholar/cluster/optics.ex
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
defmodule Scholar.Cluster.OPTICS do
@moduledoc """
OPTICS (Ordering Points To Identify the Clustering Structure), closely related to DBSCAN, finds core sample of high density and expands clusters from them. Unlike DBSCAN, keeps cluster hierarchy for a variable neighborhood radius. Clusters are then extracted using a DBSCAN-like method.
OPTICS (Ordering Points To Identify the Clustering Structure) is an algorithm
for finding density-based clusters in spatial data. It is closely related
to DBSCAN, finds core sample of high density and expands clusters from them.
Unlike DBSCAN, keeps cluster hierarchy for a variable neighborhood radius.
Clusters are then extracted using a DBSCAN-like method.
"""
import Nx.Defn
require Nx
Expand All @@ -9,37 +13,59 @@ defmodule Scholar.Cluster.OPTICS do
min_samples: [
default: 5,
type: :pos_integer,
doc: "The number of samples in a neighborhood for a point to be considered as a core point."
doc: """
The number of samples in a neighborhood for a point to be considered as a core point.
"""
],
max_eps: [
default: Nx.Constants.infinity(),
type: {:custom, Scholar.Options, :beta, []},
doc:
"The maximum distance between two samples for one to be considered as in the neighborhood of the other. Default value of Nx.Constants.infinity() will identify clusters across all scales "
doc: """
The maximum distance between two samples for one to be considered as in the neighborhood of the other.
Default value of Nx.Constants.infinity() will identify clusters across all scales.
"""
],
eps: [
default: Nx.Constants.infinity(),
type: {:custom, Scholar.Options, :beta, []},
doc:
"The maximum distance between two samples for one to be considered as in the neighborhood of the other. By default it assumes the same value as max_eps."
doc: """
The maximum distance between two samples for one to be considered as in the neighborhood of the other.
By default it assumes the same value as max_eps.
"""
],
algorithm: [
default: :brute,
type: :atom,
doc: """
Algorithm used to compute the k-nearest neighbors. Possible values:
* `:brute` - Brute-force search. See `Scholar.Neighbors.BruteKNN` for more details.
* `:kd_tree` - k-d tree. See `Scholar.Neighbors.KDTree` for more details.
* `:random_projection_forest` - Random projection forest. See `Scholar.Neighbors.RandomProjectionForest` for more details.
* Module implementing `fit(data, opts)` and `predict(model, query)`. predict/2 must return a tuple containing indices
of k-nearest neighbors of query points as well as distances between query points and their k-nearest neighbors.
Also has to take num_neighbors as argument.
"""
]
]

@opts_schema NimbleOptions.new!(opts)

@doc """
Perform OPTICS clustering for `x` which is tensor of {n_samples, n_features} shape.
Perform OPTICS clustering for `x` which is tensor of `{n_samples, n_features} shape.
## Options
#{NimbleOptions.docs(@opts_schema)}
## Return Values
The function returns a labels tensor of shape {n_samples}
Cluster labels for each point in the dataset given to fit().
Noisy samples are labeled as -1.
The function returns a labels tensor of shape `{n_samples}`.
Cluster labels for each point in the dataset given to fit().
Noisy samples are labeled as -1.
## Examples
Expand Down Expand Up @@ -69,14 +95,46 @@ defmodule Scholar.Cluster.OPTICS do
s64[6]
[0, 0, 0, 1, 1, -1]
>
iex> Scholar.Cluster.OPTICS.fit(x, max_eps: 2, min_samples: 1, algorithm: :kd_tree, metric: {:minkowski, 1})
#Nx.Tensor<
s64[6]
[0, 1, 1, 2, 2, 3]
>
"""

deftransform fit(x, opts \\ []) do
fit_p(x, NimbleOptions.validate!(opts, @opts_schema))
if Nx.rank(x) != 2 do
raise ArgumentError,
"""
expected x to have shape {num_samples, num_features}, \
got tensor with shape: #{inspect(Nx.shape(x))}
"""
end
{opts, algorithm_opts} = Keyword.split(opts, [:min_samples, :max_eps, :eps, :algorithm])
opts = NimbleOptions.validate!(opts, @opts_schema)
algorithm_opts = Keyword.put(algorithm_opts, :num_neighbors, opts[:min_samples])

algorithm_module =
case opts[:algorithm] do
:brute ->
Scholar.Neighbors.BruteKNN

:kd_tree ->
Scholar.Neighbors.KDTree

:random_projection_forest ->
Scholar.Neighbors.RandomProjectionForest

module when is_atom(module) ->
module
end
model = algorithm_module.fit(x, algorithm_opts)
{_neighbors, distances} = algorithm_module.predict(model, x)
fit_p(x, distances, opts)
end

defnp fit_p(x, opts \\ []) do
{core_distances, reachability, _predecessor, ordering} = compute_optics_graph(x, opts)
defnp fit_p(x, core_distances, opts \\ []) do
{core_distances, reachability, _predecessor, ordering} = compute_optics_graph(x, core_distances, opts)

eps =
if opts[:eps] == Nx.Constants.infinity() do
Expand All @@ -88,15 +146,12 @@ defmodule Scholar.Cluster.OPTICS do
cluster_optics_dbscan(reachability, core_distances, ordering, eps: eps)
end

defnp compute_optics_graph(x, opts \\ []) do
min_samples = opts[:min_samples]
defnp compute_optics_graph(x, distances, opts \\ []) do
max_eps = opts[:max_eps]
n_samples = Nx.axis_size(x, 0)
reachability = Nx.broadcast(Nx.Constants.max_finite({:f, 32}), {n_samples})
predecessor = Nx.broadcast(-1, {n_samples})
neighbors = Scholar.Neighbors.BruteKNN.fit(x, num_neighbors: min_samples)
core_distances = compute_core_distances(x, neighbors, min_samples: min_samples)

core_distances = Nx.slice_along_axis(distances, opts[:min_samples] - 1, 1, axis: 1)
core_distances =
Nx.select(core_distances > max_eps, Nx.Constants.infinity(), core_distances)

Expand Down Expand Up @@ -131,11 +186,6 @@ defmodule Scholar.Cluster.OPTICS do
{core_distances, reachability, predecessor, ordering}
end

defnp compute_core_distances(x, neighbors, opts \\ []) do
{_neighbors, distances} = Scholar.Neighbors.BruteKNN.predict(neighbors, x)
Nx.slice_along_axis(distances, opts[:min_samples] - 1, 1, axis: 1)
end

defnp set_reach_dist(
core_distances,
reachability,
Expand Down

0 comments on commit 5c7812e

Please sign in to comment.