Skip to content

Commit

Permalink
Deployed 95e809d with MkDocs version: 1.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
mrapplexz committed Jun 5, 2024
1 parent 8def270 commit dafc74f
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 26 deletions.
49 changes: 26 additions & 23 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -700,11 +700,13 @@ <h4 id="31-prepare">3.1. Prepare</h4>
</span><span id="__span-0-4"><a id="__codelineno-0-4" name="__codelineno-0-4" href="#__codelineno-0-4"></a><span class="n">enable_tf32</span><span class="p">()</span>
</span></code></pre></div>
<h4 id="32-implement-your-trainable">3.2. Implement your Trainable</h4>
<p>You need to implement your custom training logic by subclassing a <code>XZTrainable</code> class.</p>
<p>Trainable is used for:
* forward pass code, including loss computation;
* specifying what metrics you want to calculate while training (xztrainer uses torchmetrics for updating and calculating the metrics since it supports distributed metric computation out of the box, see <a href="https://lightning.ai/docs/torchmetrics/stable/">torchmetrics</a> docs);
* specifying some other callbacks, such as model loading callback or logging callback</p>
<p>You need to implement your custom training logic by inheriting from a <code>XZTrainable</code> class.</p>
<p>Trainable is used for:</p>
<ul>
<li>forward pass code, including loss computation;</li>
<li>specifying what metrics you want to calculate while training (xztrainer uses torchmetrics for updating and calculating the metrics since it supports distributed metric computation out of the box, see <a href="https://lightning.ai/docs/torchmetrics/stable/">torchmetrics</a> docs);</li>
<li>specifying some other callbacks, such as model loading callback or logging callback</li>
</ul>
<p>Use <a href="trainable/">xztrainer trainable docs</a> to see full list of functions you can implement in your Trainable.</p>
<p>An example that uses cross-entropy loss for an image classification model, calculating accuracy as a metric:</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-1-1"><a id="__codelineno-1-1" name="__codelineno-1-1" href="#__codelineno-1-1"></a><span class="kn">from</span> <span class="nn">xztrainer</span> <span class="kn">import</span> <span class="n">XZTrainable</span><span class="p">,</span> <span class="n">BaseContext</span><span class="p">,</span> <span class="n">DataType</span><span class="p">,</span> <span class="n">ContextType</span><span class="p">,</span> <span class="n">ModelOutputType</span>
Expand Down Expand Up @@ -740,23 +742,24 @@ <h4 id="33-create-standard-pytorch-objects">3.3. Create standard PyTorch objects
<p>You need to implement your standard PyTorch objects related to working with data. And, for sure, your model object.</p>
<h5 id="dataset">Dataset</h5>
<p>An example <a href="https://pytorch.org/docs/stable/data.html#dataset-types">dataset</a> that remaps standard torchvision CIFAR10 dataset from tuple-yielding to dictionary-yielding just for convenience.</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-2-1"><a id="__codelineno-2-1" name="__codelineno-2-1" href="#__codelineno-2-1"></a><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span>
</span><span id="__span-2-2"><a id="__codelineno-2-2" name="__codelineno-2-2" href="#__codelineno-2-2"></a><span class="kn">from</span> <span class="nn">torchvision.datasets.cifar</span> <span class="kn">import</span> <span class="n">CIFAR10</span>
</span><span id="__span-2-3"><a id="__codelineno-2-3" name="__codelineno-2-3" href="#__codelineno-2-3"></a><span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">ToTensor</span>
</span><span id="__span-2-4"><a id="__codelineno-2-4" name="__codelineno-2-4" href="#__codelineno-2-4"></a>
</span><span id="__span-2-5"><a id="__codelineno-2-5" name="__codelineno-2-5" href="#__codelineno-2-5"></a><span class="k">class</span> <span class="nc">CifarDictDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
</span><span id="__span-2-6"><a id="__codelineno-2-6" name="__codelineno-2-6" href="#__codelineno-2-6"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">train</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
</span><span id="__span-2-7"><a id="__codelineno-2-7" name="__codelineno-2-7" href="#__codelineno-2-7"></a> <span class="bp">self</span><span class="o">.</span><span class="n">base_data</span> <span class="o">=</span> <span class="n">CIFAR10</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="s1">&#39;./cifar10&#39;</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="n">train</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">ToTensor</span><span class="p">())</span>
</span><span id="__span-2-8"><a id="__codelineno-2-8" name="__codelineno-2-8" href="#__codelineno-2-8"></a>
</span><span id="__span-2-9"><a id="__codelineno-2-9" name="__codelineno-2-9" href="#__codelineno-2-9"></a> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">item</span><span class="p">):</span>
</span><span id="__span-2-10"><a id="__codelineno-2-10" name="__codelineno-2-10" href="#__codelineno-2-10"></a> <span class="n">image</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">base_data</span><span class="p">[</span><span class="n">item</span><span class="p">]</span>
</span><span id="__span-2-11"><a id="__codelineno-2-11" name="__codelineno-2-11" href="#__codelineno-2-11"></a> <span class="k">return</span> <span class="p">{</span>
</span><span id="__span-2-12"><a id="__codelineno-2-12" name="__codelineno-2-12" href="#__codelineno-2-12"></a> <span class="s1">&#39;image&#39;</span><span class="p">:</span> <span class="n">image</span><span class="p">,</span>
</span><span id="__span-2-13"><a id="__codelineno-2-13" name="__codelineno-2-13" href="#__codelineno-2-13"></a> <span class="s1">&#39;label&#39;</span><span class="p">:</span> <span class="n">label</span>
</span><span id="__span-2-14"><a id="__codelineno-2-14" name="__codelineno-2-14" href="#__codelineno-2-14"></a> <span class="p">}</span>
</span><span id="__span-2-15"><a id="__codelineno-2-15" name="__codelineno-2-15" href="#__codelineno-2-15"></a>
</span><span id="__span-2-16"><a id="__codelineno-2-16" name="__codelineno-2-16" href="#__codelineno-2-16"></a> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span><span id="__span-2-17"><a id="__codelineno-2-17" name="__codelineno-2-17" href="#__codelineno-2-17"></a> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">base_data</span><span class="p">)</span>
<div class="language-python highlight"><pre><span></span><code><span id="__span-2-1"><a id="__codelineno-2-1" name="__codelineno-2-1" href="#__codelineno-2-1"></a><span class="kn">import</span> <span class="nn">torch</span>
</span><span id="__span-2-2"><a id="__codelineno-2-2" name="__codelineno-2-2" href="#__codelineno-2-2"></a><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span>
</span><span id="__span-2-3"><a id="__codelineno-2-3" name="__codelineno-2-3" href="#__codelineno-2-3"></a><span class="kn">from</span> <span class="nn">torchvision.datasets.cifar</span> <span class="kn">import</span> <span class="n">CIFAR10</span>
</span><span id="__span-2-4"><a id="__codelineno-2-4" name="__codelineno-2-4" href="#__codelineno-2-4"></a><span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">ToTensor</span>
</span><span id="__span-2-5"><a id="__codelineno-2-5" name="__codelineno-2-5" href="#__codelineno-2-5"></a>
</span><span id="__span-2-6"><a id="__codelineno-2-6" name="__codelineno-2-6" href="#__codelineno-2-6"></a><span class="k">class</span> <span class="nc">CifarDictDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
</span><span id="__span-2-7"><a id="__codelineno-2-7" name="__codelineno-2-7" href="#__codelineno-2-7"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">train</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
</span><span id="__span-2-8"><a id="__codelineno-2-8" name="__codelineno-2-8" href="#__codelineno-2-8"></a> <span class="bp">self</span><span class="o">.</span><span class="n">base_data</span> <span class="o">=</span> <span class="n">CIFAR10</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="s1">&#39;./cifar10&#39;</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="n">train</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">ToTensor</span><span class="p">())</span>
</span><span id="__span-2-9"><a id="__codelineno-2-9" name="__codelineno-2-9" href="#__codelineno-2-9"></a>
</span><span id="__span-2-10"><a id="__codelineno-2-10" name="__codelineno-2-10" href="#__codelineno-2-10"></a> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">item</span><span class="p">):</span>
</span><span id="__span-2-11"><a id="__codelineno-2-11" name="__codelineno-2-11" href="#__codelineno-2-11"></a> <span class="n">image</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">base_data</span><span class="p">[</span><span class="n">item</span><span class="p">]</span>
</span><span id="__span-2-12"><a id="__codelineno-2-12" name="__codelineno-2-12" href="#__codelineno-2-12"></a> <span class="k">return</span> <span class="p">{</span>
</span><span id="__span-2-13"><a id="__codelineno-2-13" name="__codelineno-2-13" href="#__codelineno-2-13"></a> <span class="s1">&#39;image&#39;</span><span class="p">:</span> <span class="n">image</span><span class="p">,</span>
</span><span id="__span-2-14"><a id="__codelineno-2-14" name="__codelineno-2-14" href="#__codelineno-2-14"></a> <span class="s1">&#39;label&#39;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">scalar_tensor</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span>
</span><span id="__span-2-15"><a id="__codelineno-2-15" name="__codelineno-2-15" href="#__codelineno-2-15"></a> <span class="p">}</span>
</span><span id="__span-2-16"><a id="__codelineno-2-16" name="__codelineno-2-16" href="#__codelineno-2-16"></a>
</span><span id="__span-2-17"><a id="__codelineno-2-17" name="__codelineno-2-17" href="#__codelineno-2-17"></a> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
</span><span id="__span-2-18"><a id="__codelineno-2-18" name="__codelineno-2-18" href="#__codelineno-2-18"></a> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">base_data</span><span class="p">)</span>
</span></code></pre></div>
<h5 id="collator">Collator</h5>
<p>An example <a href="https://pytorch.org/data/main/generated/torchdata.datapipes.iter.Collator.html">collate function</a> that stacks images and labels</p>
Expand Down Expand Up @@ -834,7 +837,7 @@ <h3 id="5-explore-saved-artifacts">5. Explore saved artifacts</h3>
<p>Inside a <code>project_dir</code> you specified in a <code>Accelerator</code> configuration, you will see:</p>
<ul>
<li>Saved checkpoints inside <code>checkpoint</code> directory</li>
<li>In case of logging enabled - logging artifacts in <code>runs</code> directory</li>
<li>In case of logging enabled (<code>log_with</code> parameter for <code>Accelerator</code>) - logging artifacts in <code>runs</code> directory</li>
</ul>


Expand Down
Binary file modified objects.inv
Binary file not shown.
2 changes: 1 addition & 1 deletion search/search_index.json

Large diffs are not rendered by default.

Binary file modified sitemap.xml.gz
Binary file not shown.
66 changes: 66 additions & 0 deletions trainable/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,15 @@
</span>
</a>

</li>

<li class="md-nav__item">
<a href="#xztrainer.trainable.XZTrainable.tracker_config" class="md-nav__link">
<span class="md-ellipsis">
tracker_config
</span>
</a>

</li>

<li class="md-nav__item">
Expand Down Expand Up @@ -1062,6 +1071,15 @@
</span>
</a>

</li>

<li class="md-nav__item">
<a href="#xztrainer.trainable.XZTrainable.tracker_config" class="md-nav__link">
<span class="md-ellipsis">
tracker_config
</span>
</a>

</li>

<li class="md-nav__item">
Expand Down Expand Up @@ -2297,6 +2315,54 @@ <h3 id="xztrainer.trainable.XZTrainable.step" class="doc doc-heading">
<div class="doc doc-object doc-function">


<h3 id="xztrainer.trainable.XZTrainable.tracker_config" class="doc doc-heading">
<code class="highlight language-python"><span class="n">tracker_config</span><span class="p">(</span><span class="n">context</span><span class="p">)</span></code>

</h3>


<div class="doc doc-contents ">

<p>Function returning additional hyperparameters that will be logged to experiment tracker. This function
is called once when training starts.</p>


<p><span class="doc-section-title">Parameters:</span></p>
<table>
<thead>
<tr>
<th>Name</th>
<th>Type</th>
<th>Description</th>
<th>Default</th>
</tr>
</thead>
<tbody>
<tr class="doc-section-item">
<td><code>context</code></td>
<td>
<code><a class="autorefs autorefs-internal" title="xztrainer.context.TrainContext" href="#xztrainer.context.TrainContext">TrainContext</a></code>
</td>
<td>
<div class="doc-md-description">
<p>Current <strong>train</strong> context</p>
</div>
</td>
<td>
<em>required</em>
</td>
</tr>
</tbody>
</table>
<p>Returns: Dictionary with additional parameter names and their values.</p>

</div>

</div>

<div class="doc doc-object doc-function">


<h3 id="xztrainer.trainable.XZTrainable.update_metrics" class="doc doc-heading">
<code class="highlight language-python"><span class="n">update_metrics</span><span class="p">(</span><span class="n">context_type</span><span class="p">,</span> <span class="n">model_outputs</span><span class="p">,</span> <span class="n">metrics</span><span class="p">)</span></code>

Expand Down
44 changes: 42 additions & 2 deletions trainer/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,15 @@
</span>
</a>

</li>

<li class="md-nav__item">
<a href="#xztrainer.model.XZTrainerConfig.tracker_init_kwargs" class="md-nav__link">
<span class="md-ellipsis">
tracker_init_kwargs
</span>
</a>

</li>

</ul>
Expand Down Expand Up @@ -789,6 +798,15 @@
</span>
</a>

</li>

<li class="md-nav__item">
<a href="#xztrainer.model.XZTrainerConfig.tracker_init_kwargs" class="md-nav__link">
<span class="md-ellipsis">
tracker_init_kwargs
</span>
</a>

</li>

</ul>
Expand Down Expand Up @@ -1370,7 +1388,29 @@ <h3 id="xztrainer.model.XZTrainerConfig.tracker_config" class="doc doc-heading">

<div class="doc doc-contents ">

<p>Arbitrary config used for Accelerate experiment tracker. Directly passed to <code>accelerator.init_trackers(..., tracker_config)</code>. See <a href="https://huggingface.co/docs/accelerate/en/usage_guides/tracking">Accelerate docs</a></p>
<p>Arbitrary hyperparameters logged to experiment tracker.</p>
</div>

</div>

<div class="doc doc-object doc-attribute">



<h3 id="xztrainer.model.XZTrainerConfig.tracker_init_kwargs" class="doc doc-heading">
<code class="highlight language-python"><span class="n">tracker_init_kwargs</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">dict</span><span class="p">)</span></code>

<span class="doc doc-labels">
<small class="doc doc-label doc-label-class-attribute"><code>class-attribute</code></small>
<small class="doc doc-label doc-label-instance-attribute"><code>instance-attribute</code></small>
</span>

</h3>


<div class="doc doc-contents ">

<p>Arbitrary keyword arguments used for Accelerate experiment tracker. Directly passed to <code>accelerator.init_trackers(...)</code>. See <a href="https://huggingface.co/docs/accelerate/en/usage_guides/tracking">Accelerate docs</a></p>
</div>

</div>
Expand Down Expand Up @@ -1595,7 +1635,7 @@ <h4 id="xztrainer.trainer.XZTrainer.train" class="doc doc-heading">
<tr class="doc-section-item">
<td><code>eval_data</code></td>
<td>
<code><span title="torch.utils.data.Dataset">Dataset</span> | None</code>
<code><span title="typing.Optional">Optional</span>[<span title="torch.utils.data.Dataset">Dataset</span>]</code>
</td>
<td>
<div class="doc-md-description">
Expand Down

0 comments on commit dafc74f

Please sign in to comment.