This is a question related to issue #51. Placing the question here because it is general in nature. I have defined the following function:

```scala
def crossEntropy[
    I <: BFloat16 | Float32 | Float64,
    O <: NumericRealNN
](
    input: Tensor[I],
    target: Tensor[O]
): Tensor[I] =
  Tensor(
    torchNative.cross_entropy(
      input.native,
      target.native
    )
  )
```

But I noticed that the existing softmax is defined differently:

```scala
def softmax[In <: DType, Out <: DType](input: Tensor[In], dim: Long)(
    dtype: Out = input.dtype
): Tensor[Out] =
  val nativeDType =
    if dtype == input.dtype then ScalarTypeOptional() else ScalarTypeOptional(dtype.toScalarType)
  Tensor(torchNative.softmax(input.native, dim, nativeDType))
```

I was wondering if there is some rule to use when encoding the types. So my questions are:

TIA
Replies: 1 comment 1 reply
-
We want to be as specific as possible to avoid runtime errors, as many ops only support floating point inputs, for example. It's not always easy to find out though, as it is encoded in the C++ kernels. The existing softmax for instance is too generic, but the updated one in #51 (comment) only accepts floats.

One way to find out is to actually run the method with different input dtypes. If it fails, we get a nice stacktrace that often leads us to the kernel implementation where we can see the dtype restrictions; for softmax that's SoftMaxKernel.cpp. We have a way to do that more systematically via property-based tests thanks to @davoclavo, see #23 (comment). It's just not applied to all operations yet.
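As an illustration (a minimal sketch, not the actual signature from #51, and using only the type and binding names that already appear in the snippets above), narrowing the type parameter to floating point dtypes moves the failure from runtime to compile time:

```scala
// Sketch only: assumes the same imports/scope as the snippets above.
// With the dtype bound narrowed to floating point types, calling this
// with e.g. an integer tensor no longer compiles.
def softmaxFloatOnly[D <: BFloat16 | Float32 | Float64](
    input: Tensor[D],
    dim: Long
): Tensor[D] =
  // No dtype override, so pass an empty ScalarTypeOptional to the native call.
  Tensor(torchNative.softmax(input.native, dim, ScalarTypeOptional()))
```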
For these cases it's often useful to look into the Python API docs for orientation. E.g. for softmax you can explicitly override the output type via the dtype parameter. For cross_entropy, as you've said, the output type is the same as the input and there's no parameter to override it.
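Concretely, with the crossEntropy from the question, the result dtype just follows the input dtype at the type level. A small usage sketch with hypothetical names (logits and labels are assumed tensors, and Int64 is assumed to satisfy NumericRealNN):

```scala
// Hypothetical example: `logits` is a Tensor[Float32] of raw scores and
// `labels` is a Tensor[Int64] of class indices (Int64 assumed to be a
// subtype of NumericRealNN). The loss dtype matches the input dtype.
val loss: Tensor[Float32] = crossEntropy(logits, labels)
```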