Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to specifying position using position_column parameter #6825

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions R-package/R/aliases.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
, "two_round"
, "use_missing"
, "weight_column"
, "position_column"
, "zero_as_missing"
)])
}
Expand Down
1 change: 1 addition & 0 deletions R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,7 @@ test_that("all parameters are stored correctly with save_model_to_string()", {
, "[label_column: ]"
, "[weight_column: ]"
, "[group_column: ]"
, "[position_column: ]"
, "[ignore_column: ]"
, "[categorical_feature: ]"
, "[forcedbins_filename: ]"
Expand Down
12 changes: 12 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,18 @@ Dataset Parameters

- **Note**: index starts from ``0`` and it doesn't count the label column when passing type is ``int``, e.g. when label is column\_0 and query\_id is column\_1, the correct parameter is ``query=0``

- ``position_column`` :raw-html:`<a id="position_column" title="Permalink to this parameter" href="#position_column">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = int or string, aliases: ``position``, ``position_id``, ``position_column``

- used to specify the position/position id column

- use number for index, e.g. ``position=0`` means column\_0 is the position id

- add a prefix ``name:`` for column name, e.g. ``position=name:position_id``

- **Note**: works only in case of loading data directly from text file

- **Note**: index starts from ``0`` and it doesn't count the label column when passing type is ``int``, e.g. when label is column\_0 and position\_id is column\_1, the correct parameter is ``position=0``

- ``ignore_column`` :raw-html:`<a id="ignore_column" title="Permalink to this parameter" href="#ignore_column">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = multi-int or string, aliases: ``ignore_feature``, ``blacklist``

- used to specify some ignoring columns in training
Expand Down
6 changes: 6 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc
* \param has_weights Whether the dataset has Metadata weights
* \param has_init_scores Whether the dataset has Metadata initial scores
* \param has_queries Whether the dataset has Metadata queries/groups
* \param has_positions Whether the dataset has Metadata positions
* \param nclasses Number of initial score classes
* \param nthreads Number of external threads that will use the PushRows APIs
* \param omp_max_threads Maximum number of OpenMP threads (-1 for default)
Expand All @@ -178,6 +179,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetInitStreaming(DatasetHandle dataset,
int32_t has_weights,
int32_t has_init_scores,
int32_t has_queries,
int32_t has_positions,
int32_t nclasses,
int32_t nthreads,
int32_t omp_max_threads);
Expand Down Expand Up @@ -233,6 +235,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
* \param weight Optional pointer to array with nrow weights
* \param init_score Optional pointer to array with nrow*nclasses initial scores, in column format
* \param query Optional pointer to array with nrow query values
* \param position Optional pointer to array with nrow position values
* \param tid The id of the calling thread, from 0...N-1 threads
* \return 0 when succeed, -1 when failure happens
*/
Expand All @@ -246,6 +249,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset,
const float* weight,
const double* init_score,
const int32_t* query,
const int32_t* position,
int32_t tid);

/*!
Expand Down Expand Up @@ -288,6 +292,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
* \param weight Optional pointer to array with nindptr-1 weights
* \param init_score Optional pointer to array with (nindptr-1)*nclasses initial scores, in column format
* \param query Optional pointer to array with nindptr-1 query values
* \param position Optional pointer to array with nindptr-1 position values
* \param tid The id of the calling thread, from 0...N-1 threads
* \return 0 when succeed, -1 when failure happens
*/
Expand All @@ -304,6 +309,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle datase
const float* weight,
const double* init_score,
const int32_t* query,
const int32_t* position,
int32_t tid);

/*!
Expand Down
9 changes: 9 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,15 @@ struct Config {
// desc = **Note**: index starts from ``0`` and it doesn't count the label column when passing type is ``int``, e.g. when label is column\_0 and query\_id is column\_1, the correct parameter is ``query=0``
std::string group_column = "";

// type = int or string
// alias = position, position_id, position_column
// desc = used to specify the position/position id column
// desc = use number for index, e.g. ``position=0`` means column\_0 is the position id
// desc = add a prefix ``name:`` for column name, e.g. ``position=name:position_id``
// desc = **Note**: works only in case of loading data directly from text file
// desc = **Note**: index starts from ``0`` and it doesn't count the label column when passing type is ``int``, e.g. when label is column\_0 and position\_id is column\_1, the correct parameter is ``position=0``
std::string position_column = "";

// type = multi-int or string
// alias = ignore_feature, blacklist
// desc = used to specify some ignoring columns in training
Expand Down
54 changes: 33 additions & 21 deletions include/LightGBM/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ class Metadata {
* \param num_data Number of training data
* \param weight_idx Index of weight column, < 0 means it doesn't exist
* \param query_idx Index of query id column, < 0 means it doesn't exist
* \param position_idx Index of position id column, < 0 means it doesn't exist
*/
void Init(data_size_t num_data, int weight_idx, int query_idx);
void Init(data_size_t num_data, int weight_idx, int query_idx, int position_idx);

/*!
* \brief Allocate space for label, weight (if exists), initial score (if exists) and query (if exists)
Expand All @@ -92,9 +93,10 @@ class Metadata {
* \param has_weights Whether the metadata has weights
* \param has_init_scores Whether the metadata has initial scores
* \param has_queries Whether the metadata has queries
* \param has_positions Whether the metadata has positions
* \param nclasses Number of classes for initial scores
*/
void Init(data_size_t num_data, int32_t has_weights, int32_t has_init_scores, int32_t has_queries, int32_t nclasses);
void Init(data_size_t num_data, int32_t has_weights, int32_t has_init_scores, int32_t has_queries, int32_t has_positions, int32_t nclasses);

/*!
* \brief Partition label by used indices
Expand All @@ -120,6 +122,7 @@ class Metadata {
void SetQuery(const ArrowChunkedArray& array);

void SetPosition(const data_size_t* position, data_size_t len);
void SetPosition(const ArrowChunkedArray& array);

/*!
* \brief Set initial scores
Expand Down Expand Up @@ -186,6 +189,15 @@ class Metadata {
queries_[idx] = static_cast<data_size_t>(value);
}

/*!
* \brief Set Position Id for one record
* \param idx Index of this record
* \param value Position Id value of this record
* \note Writes straight into positions_ via operator[], so no bounds check is
*       performed here; presumably positions_ has already been sized to cover
*       idx before this is called -- TODO confirm against the callers.
*/
inline void SetPositionAt(data_size_t idx, data_size_t value) {
// value is already a data_size_t, so this cast does not change the value.
positions_[idx] = static_cast<data_size_t>(value);
}

/*! \brief Load initial scores from file */
void LoadInitialScore(const std::string& data_filename);

Expand All @@ -197,13 +209,15 @@ class Metadata {
* \param weights Pointer to weight data, or null
* \param init_scores Pointer to init-score data, or null
* \param queries Pointer to query data, or null
* \param positions Pointer to position data, or null
*/
void InsertAt(data_size_t start_index,
data_size_t count,
const float* labels,
const float* weights,
const double* init_scores,
const int32_t* queries);
const int32_t* queries,
const int32_t* positions);

/*!
* \brief Perform any extra operations after all data has been loaded
Expand Down Expand Up @@ -233,24 +247,17 @@ class Metadata {
}
}

/*!
* \brief Get position IDs, if does not exist then return nullptr
* \return Pointer of position IDs
*/
inline const std::string* position_ids() const {
if (!position_ids_.empty()) {
return position_ids_.data();
} else {
return nullptr;
}
}

/*!
* \brief Get the number of position IDs, computed as the largest position ID plus one
* \return largest position ID plus one, or 0 if no positions are set
*/
inline size_t num_position_ids() const {
return position_ids_.size();
if (!positions_.empty()) {
size_t max = *std::max_element(positions_.begin(), positions_.end());
return max + 1;
} else {
return 0;
}
}

/*!
Expand Down Expand Up @@ -354,6 +361,11 @@ class Metadata {
void SetInitScoresFromIterator(It first, It last);
/*! \brief Insert queries at the given index */
void InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len);
/*! \brief Set positions from pointers to the first element and the end of an iterator. */
template <typename It>
void SetPositionsFromIterator(It first, It last);
/*! \brief Insert positions at the given index */
void InsertPositions(const data_size_t* positions, data_size_t start_index, data_size_t len);
/*! \brief Set queries from pointers to the first element and the end of an iterator. */
template <typename It>
void SetQueriesFromIterator(It first, It last);
Expand All @@ -371,8 +383,6 @@ class Metadata {
std::vector<label_t> weights_;
/*! \brief Positions data */
std::vector<data_size_t> positions_;
/*! \brief Position identifiers */
std::vector<std::string> position_ids_;
/*! \brief Query boundaries */
std::vector<data_size_t> query_boundaries_;
/*! \brief Query weights */
Expand Down Expand Up @@ -519,6 +529,7 @@ class Dataset {
int32_t has_weights,
int32_t has_init_scores,
int32_t has_queries,
int32_t has_positions,
int32_t nclasses,
int32_t nthreads,
int32_t omp_max_threads) {
Expand All @@ -529,7 +540,7 @@ class Dataset {
omp_max_threads_ = OMP_NUM_THREADS();
}

metadata_.Init(num_data, has_weights, has_init_scores, has_queries, nclasses);
metadata_.Init(num_data, has_weights, has_init_scores, has_queries, has_positions, nclasses);
for (int i = 0; i < num_groups_; ++i) {
feature_groups_[i]->InitStreaming(nthreads, omp_max_threads_);
}
Expand Down Expand Up @@ -623,8 +634,9 @@ class Dataset {
const label_t* labels,
const label_t* weights,
const double* init_scores,
const data_size_t* queries) {
metadata_.InsertAt(start_index, count, labels, weights, init_scores, queries);
const data_size_t* queries,
const data_size_t* positions) {
metadata_.InsertAt(start_index, count, labels, weights, init_scores, queries, positions);
}

inline int RealFeatureIndex(int fidx) const {
Expand Down
2 changes: 2 additions & 0 deletions include/LightGBM/dataset_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class DatasetLoader {
int weight_idx_;
/*! \brief index of group column */
int group_idx_;
/*! \brief index of position column */
int position_idx_;
/*! \brief Mapper from real feature index to used index*/
std::unordered_set<int> ignore_features_;
/*! \brief store feature names */
Expand Down
1 change: 1 addition & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2042,6 +2042,7 @@ def get_params(self) -> Dict[str, Any]:
"two_round",
"use_missing",
"weight_column",
"position_column",
"zero_as_missing",
)
return {k: v for k, v in self.params.items() if k in dataset_params}
Expand Down
14 changes: 11 additions & 3 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ class Booster {
Log::Fatal(
"Cannot change group_column after constructed Dataset handle.");
}
if (new_param.count("position_column") &&
new_config.position_column != old_config.position_column) {
Log::Fatal(
"Cannot change position_column after constructed Dataset handle.");
}
if (new_param.count("ignore_column") &&
new_config.ignore_column != old_config.ignore_column) {
Log::Fatal(
Expand Down Expand Up @@ -1114,13 +1119,14 @@ int LGBM_DatasetInitStreaming(DatasetHandle dataset,
int32_t has_weights,
int32_t has_init_scores,
int32_t has_queries,
int32_t has_positions,
int32_t nclasses,
int32_t nthreads,
int32_t omp_max_threads) {
API_BEGIN();
auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto num_data = p_dataset->num_data();
p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, nclasses, nthreads, omp_max_threads);
p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, has_positions, nclasses, nthreads, omp_max_threads);
p_dataset->set_wait_for_manual_finish(true);
API_END();
}
Expand Down Expand Up @@ -1163,6 +1169,7 @@ int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset,
const float* weights,
const double* init_scores,
const int32_t* queries,
const int32_t* positions,
int32_t tid) {
API_BEGIN();
#ifdef LABEL_T_USE_DOUBLE
Expand Down Expand Up @@ -1191,7 +1198,7 @@ int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset,
}
OMP_THROW_EX();

p_dataset->InsertMetadataAt(start_row, nrow, labels, weights, init_scores, queries);
p_dataset->InsertMetadataAt(start_row, nrow, labels, weights, init_scores, queries, positions);

if (!p_dataset->wait_for_manual_finish() && (start_row + nrow == p_dataset->num_data())) {
p_dataset->FinishLoad();
Expand Down Expand Up @@ -1245,6 +1252,7 @@ int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle dataset,
const float* weights,
const double* init_scores,
const int32_t* queries,
const int32_t* positions,
int32_t tid) {
API_BEGIN();
#ifdef LABEL_T_USE_DOUBLE
Expand Down Expand Up @@ -1274,7 +1282,7 @@ int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle dataset,
}
OMP_THROW_EX();

p_dataset->InsertMetadataAt(static_cast<int32_t>(start_row), nrow, labels, weights, init_scores, queries);
p_dataset->InsertMetadataAt(static_cast<int32_t>(start_row), nrow, labels, weights, init_scores, queries, positions);

if (!p_dataset->wait_for_manual_finish() && (start_row + nrow == static_cast<int64_t>(p_dataset->num_data()))) {
p_dataset->FinishLoad();
Expand Down
9 changes: 9 additions & 0 deletions src/io/config_auto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ const std::unordered_map<std::string, std::string>& Config::alias_table() {
{"query_column", "group_column"},
{"query", "group_column"},
{"query_id", "group_column"},
{"position", "position_column"},
{"position_id", "position_column"},
{"position_column", "position_column"},
{"ignore_feature", "ignore_column"},
{"blacklist", "ignore_column"},
{"cat_feature", "categorical_feature"},
Expand Down Expand Up @@ -274,6 +277,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"label_column",
"weight_column",
"group_column",
"position_column",
"ignore_column",
"categorical_feature",
"forcedbins_filename",
Expand Down Expand Up @@ -552,6 +556,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str

GetString(params, "group_column", &group_column);

GetString(params, "position_column", &position_column);

GetString(params, "ignore_column", &ignore_column);

GetString(params, "categorical_feature", &categorical_feature);
Expand Down Expand Up @@ -754,6 +760,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[label_column: " << label_column << "]\n";
str_buf << "[weight_column: " << weight_column << "]\n";
str_buf << "[group_column: " << group_column << "]\n";
str_buf << "[position_column: " << position_column << "]\n";
str_buf << "[ignore_column: " << ignore_column << "]\n";
str_buf << "[categorical_feature: " << categorical_feature << "]\n";
str_buf << "[forcedbins_filename: " << forcedbins_filename << "]\n";
Expand Down Expand Up @@ -883,6 +890,7 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
{"label_column", {"label"}},
{"weight_column", {"weight"}},
{"group_column", {"group", "group_id", "query_column", "query", "query_id"}},
{"position_column", {"position", "position_id", "position_column"}},
{"ignore_column", {"ignore_feature", "blacklist"}},
{"categorical_feature", {"cat_feature", "categorical_column", "cat_column", "categorical_features"}},
{"forcedbins_filename", {}},
Expand Down Expand Up @@ -1028,6 +1036,7 @@ const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
{"label_column", "string"},
{"weight_column", "string"},
{"group_column", "string"},
{"position_column", "string"},
{"ignore_column", "vector<int>"},
{"categorical_feature", "vector<int>"},
{"forcedbins_filename", "string"},
Expand Down
2 changes: 1 addition & 1 deletion src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Dataset::Dataset(data_size_t num_data) {
CHECK_GT(num_data, 0);
data_filename_ = "noname";
num_data_ = num_data;
metadata_.Init(num_data_, NO_SPECIFIC, NO_SPECIFIC);
metadata_.Init(num_data_, NO_SPECIFIC, NO_SPECIFIC, NO_SPECIFIC);
is_finish_load_ = false;
wait_for_manual_finish_ = false;
group_bin_boundaries_.push_back(0);
Expand Down
Loading
Loading