diff --git a/pkg/specter/unitloading.go b/pkg/specter/unitloading.go index 36347a7..6773701 100644 --- a/pkg/specter/unitloading.go +++ b/pkg/specter/unitloading.go @@ -66,6 +66,15 @@ func UnitOf[T any](v T, id UnitID, kind UnitKind, source Source) *WrappingUnit[T } } +func UnwrapUnit[T any](unit Unit) (value T, ok bool) { + w, ok := unit.(*WrappingUnit[T]) + if !ok { + return value, false + } + + return w.wrapped, true +} + func (w *WrappingUnit[T]) ID() UnitID { return w.id } diff --git a/pkg/specter/unitloading_test.go b/pkg/specter/unitloading_test.go index 8c45fe5..d07ec98 100644 --- a/pkg/specter/unitloading_test.go +++ b/pkg/specter/unitloading_test.go @@ -410,3 +410,40 @@ func TestUnitLoaderAdapter(t *testing.T) { assert.Nil(t, units) }) } + +func TestUnwrapUnit(t *testing.T) { + type then[T any] struct { + value T + ok bool + } + type testCase[T any] struct { + name string + when specter.Unit + then then[T] + } + tests := []testCase[string]{ + { + name: "Unwrap of non wrapped unit should return zero value and false", + when: testutils.NewUnitStub("id", "kind", specter.Source{}), + then: then[string]{ + value: "", + ok: false, + }, + }, + { + name: "Unwrap of a wrapped unit should return the value and true", + when: specter.UnitOf("hello", "id", "kind", specter.Source{}), + then: then[string]{ + value: "hello", + ok: true, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotValue, gotOk := specter.UnwrapUnit[string](tt.when) + assert.Equal(t, tt.then.value, gotValue) + assert.Equal(t, tt.then.ok, gotOk) + }) + } +}