Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional token_program_id param to get_associated_token_address #117

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Changed

- Add optional `token_program_id` param to `get_associated_token_address` [(#117)](https://github.com/kevinheavey/solders/pull/117).
- Upgrade Solana deps to 2.0 [(#116)](https://github.com/kevinheavey/solders/pull/116).
- Remove GetStakeActivationResp (no longer exists) [(#116)](https://github.com/kevinheavey/solders/pull/116).

Expand Down
19 changes: 17 additions & 2 deletions crates/token/src/associated.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
use pyo3::prelude::*;
use solders_pubkey::Pubkey;
use spl_associated_token_account_client::address::get_associated_token_address as get_ata;
use spl_associated_token_account_client::address::get_associated_token_address_with_program_id as get_ata;

/// Derives the associated token account address for the given wallet address and token mint.
///
/// Args:
/// wallet_address (Pubkey): The address of the wallet that owns the token account.
/// token_mint_address (Pubkey): The token mint.
/// token_program_id (Pubkey | None): The token program ID. Defaults to the SPL Token Program.
///
/// Returns:
/// Pubkey: The associated token address
///
#[pyfunction]
pub fn get_associated_token_address(
wallet_address: &Pubkey,
token_mint_address: &Pubkey,
token_program_id: Option<&Pubkey>,
) -> Pubkey {
get_ata(wallet_address.as_ref(), token_mint_address.as_ref()).into()
get_ata(
wallet_address.as_ref(),
token_mint_address.as_ref(),
token_program_id.map_or(&spl_token::ID, |x| x.as_ref()),
)
.into()
}

pub fn create_associated_mod(py: Python<'_>) -> PyResult<&PyModule> {
Expand Down
6 changes: 5 additions & 1 deletion python/solders/token/associated.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Optional

from solders.pubkey import Pubkey

def get_associated_token_address(
wallet_address: Pubkey, token_mint_address: Pubkey
wallet_address: Pubkey,
token_mint_address: Pubkey,
token_program_id: Optional[Pubkey] = None,
) -> Pubkey: ...
2 changes: 2 additions & 0 deletions tests/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def test_program_position() -> None:
assert message.program_position(1) == 0
assert message.program_position(2) == 1


def test_program_ids() -> None:
key0 = Pubkey.new_unique()
key1 = Pubkey.new_unique()
Expand All @@ -93,6 +94,7 @@ def test_program_ids() -> None:
)
assert message.program_ids() == [loader2]


def test_message_header_len_constant() -> None:
assert MessageHeader.LENGTH == 3

Expand Down
1 change: 1 addition & 0 deletions tests/test_rpc_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,7 @@ def test_get_slot_leaders() -> None:
)
]


def test_get_supply() -> None:
raw = """{
"jsonrpc": "2.0",
Expand Down
2 changes: 2 additions & 0 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,7 @@ def test_tx_uses_nonce_first_prog_id_not_nonce_fail() -> None:
tx = Transaction([from_keypair, nonce_keypair], message, Hash.default())
assert tx.uses_durable_nonce() is None


def test_tx_uses_nonce_wrong_first_nonce_ix_fail() -> None:
from_keypair = Keypair()
from_pubkey = from_keypair.pubkey()
Expand All @@ -935,6 +936,7 @@ def test_tx_uses_nonce_wrong_first_nonce_ix_fail() -> None:
tx = Transaction([from_keypair, nonce_keypair], message, Hash.default())
assert tx.uses_durable_nonce() is None


def test_tx_keypair_pubkey_mismatch() -> None:
from_keypair = Keypair()
from_pubkey = from_keypair.pubkey()
Expand Down
14 changes: 14 additions & 0 deletions tests/token/test_ata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from solders.pubkey import Pubkey
from solders.token.associated import get_associated_token_address


def test_ata() -> None:
wallet_address = Pubkey.from_string("5d21Nx19eZBThbExCn1ESAk3RGmE8Rdp9PKMWZ2VedSK")
token_mint = Pubkey.from_string("3CqfBkrmRsK3uXZaxktvTeeBkJp4yeFKs4mUi2jhKExz")
assert get_associated_token_address(
wallet_address, token_mint
) == Pubkey.from_string("Aumq2SPVzZccYL3UAhvXoDDkNYLZr2zpyxLuJiyx79te")
token22_id = Pubkey.from_string("TokenzQdBNbLqP5VEhdkAS6EPFLC1PHnBqCXEpPxuEb")
assert get_associated_token_address(
wallet_address, token_mint, token22_id
) == Pubkey.from_string("4xoV4cxTM3GcaWP7bKbUdu2Gp9P9nEgpmCPV8ykFGo4U")
Loading