Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunk committed Feb 5, 2024
1 parent 11eb92c commit 8e05618
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
4 changes: 2 additions & 2 deletions core/src/test/scala/TrainingSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ class TraininSuite extends munit.FunSuite {
noGrad {
weight.grad.foreach { grad =>
weight -= grad * learningRate
grad.zero()
grad.zero_()
}
bias.grad.foreach { grad =>
weight -= grad * learningRate
grad.zero()
grad.zero_()
}
}
loss
Expand Down
8 changes: 1 addition & 7 deletions core/src/test/scala/torch/DeviceSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,9 @@
package torch

import munit.ScalaCheckSuite
import torch.DeviceType.CUDA
import org.scalacheck.Prop.*
import org.bytedeco.pytorch.global.torch as torch_native
import org.scalacheck.{Arbitrary, Gen}
import org.scalacheck._
import Gen._
import Arbitrary.arbitrary
import DeviceType.CPU
import Generators.{*, given}
import Generators.given

class DeviceSuite extends ScalaCheckSuite {
test("device native roundtrip") {
Expand Down

0 comments on commit 8e05618

Please sign in to comment.