Skip to content

Commit 6b8b899

Browse files
committed
address selectcols issue #991
1 parent 82b4e7f commit 6b8b899

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

src/interface/data_utils.jl

+13-3
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer})
101101
return Tables.getcolumn(cols, c)
102102
end
103103

104-
function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Colon, AbstractArray})
104+
function MMI.selectcols(::FI, ::Val{:table}, X, c)
105105
if isdataframe(X)
106106
return X[!, c]
107107
else
@@ -115,18 +115,28 @@ end
115115
# utils for `select`*
116116

117117
# project named tuple onto a tuple with only specified `labels` or indices:
118-
function project(t::NamedTuple, labels::AbstractArray{Symbol})
118+
function project(t::NamedTuple, labels::Union{AbstractArray{Symbol},NTuple{<:Any,Symbol}})
119119
return NamedTuple{tuple(labels...)}(t)
120120
end
121121

122122
project(t::NamedTuple, label::Colon) = t
123123
project(t::NamedTuple, label::Symbol) = project(t, [label,])
124124
project(t::NamedTuple, i::Integer) = project(t, [i,])
125125

126-
function project(t::NamedTuple, indices::AbstractArray{<:Integer})
126+
function project(
127+
t::NamedTuple,
128+
indices::AbstractArray{<:Integer},
129+
)
127130
return NamedTuple{tuple(keys(t)[indices]...)}(tuple([t[i] for i in indices]...))
128131
end
129132

133+
function project(
134+
t::NamedTuple,
135+
indices::NTuple{<:Any,<:Integer},
136+
)
137+
return NamedTuple{tuple(keys(t)[[indices...]]...)}(tuple([t[i] for i in indices]...))
138+
end
139+
130140
# utils for selectrows
131141
typename(X) = split(string(supertype(typeof(X))), '.')[end]
132142
isdataframe(X) = typename(X) == "AbstractDataFrame"

test/interface/data_utils.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,9 @@ end
149149
s = schema(tt)
150150
@test nrows(tt) == N
151151

152-
@test selectcols(tt, 4:6) ==
152+
@test selectcols(tt, 4:6) == selectcols(tt, (4, 5, 6)) ==
153+
selectcols(tt, (:x4, :x5, :z)) ==
154+
selectcols(tt, [:x4, :x5, :z]) ==
153155
selectcols(TypedTables.Table(x4=tt.x4, x5=tt.x5, z=tt.z), :)
154156
@test selectcols(tt, [:x1, :z]) ==
155157
selectcols(TypedTables.Table(x1=tt.x1, z=tt.z), :)

0 commit comments

Comments
 (0)