diff --git a/src/fastcs/attributes.py b/src/fastcs/attributes.py index e7e44fd7..0cb871e7 100644 --- a/src/fastcs/attributes.py +++ b/src/fastcs/attributes.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from enum import Enum from typing import Any, Generic, Protocol, runtime_checkable @@ -65,6 +66,10 @@ def __init__( self._allowed_values: list[T] | None = allowed_values self.description = description + # A callback to use when setting the datatype to a different value, for example + # changing the units on an int. This should be implemented in the backend. + self._update_datatype_callbacks: list[Callable[[DataType[T]], None]] = [] + @property def datatype(self) -> DataType[T]: return self._datatype @@ -85,6 +90,20 @@ def group(self) -> str | None: def allowed_values(self) -> list[T] | None: return self._allowed_values + def add_update_datatype_callback( + self, callback: Callable[[DataType[T]], None] + ) -> None: + self._update_datatype_callbacks.append(callback) + + def update_datatype(self, datatype: DataType[T]) -> None: + if not isinstance(self._datatype, type(datatype)): + raise ValueError( + f"Attribute datatype must be of type {type(self._datatype)}" + ) + self._datatype = datatype + for callback in self._update_datatype_callbacks: + callback(datatype) + class AttrR(Attribute[T]): """A read-only ``Attribute``.""" diff --git a/src/fastcs/backends/epics/ioc.py b/src/fastcs/backends/epics/ioc.py index dbe40a5f..0f1327e5 100644 --- a/src/fastcs/backends/epics/ioc.py +++ b/src/fastcs/backends/epics/ioc.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import asdict, dataclass from types import MethodType from typing import Any, Literal @@ -15,7 +15,7 @@ enum_value_to_index, ) from fastcs.controller import BaseController -from fastcs.datatypes import Bool, Float, Int, String, T +from fastcs.datatypes import Bool, DataType, Float, Int, String, T from fastcs.exceptions import FastCSException from fastcs.mapping import Mapping @@ -27,6 +27,26 @@ class EpicsIOCOptions: terminal: bool = True +DATATYPE_NAME_TO_RECORD_FIELD = { + "prec": "PREC", + "units": "EGU", + "min": "DRVL", + "max": "DRVH", + "min_alarm": "LOPR", + "max_alarm": "HOPR", + "znam": "ZNAM", + "onam": "ONAM", +} + + +def datatype_to_epics_fields(datatype: DataType) -> dict[str, Any]: + return { + DATATYPE_NAME_TO_RECORD_FIELD[field]: value + for field, value in asdict(datatype).items() + if field in DATATYPE_NAME_TO_RECORD_FIELD + } + + class EpicsIOC: def __init__(self, pv_prefix: str, mapping: Mapping): _add_pvi_info(f"{pv_prefix}:PVI") @@ -184,36 +204,38 @@ def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper: return builder.mbbIn(pv, **state_keys, **attribute_fields) match attribute.datatype: - case Bool(znam, onam): - return builder.boolIn(pv, ZNAM=znam, ONAM=onam, **attribute_fields) - case Int(units, min, max, min_alarm, max_alarm): - return builder.longIn( + case Bool(): + record = builder.boolIn( + pv, **datatype_to_epics_fields(attribute.datatype), **attribute_fields + ) + case Int(): + record = builder.longIn( pv, - EGU=units, - DRVL=min, - DRVH=max, - LOPR=min_alarm, - HOPR=max_alarm, + **datatype_to_epics_fields(attribute.datatype), **attribute_fields, ) - case Float(prec, units, min, max, min_alarm, max_alarm): - return builder.aIn( + case Float(): + record = builder.aIn( pv, - PREC=prec, - EGU=units, - DRVL=min, - DRVH=max, - LOPR=min_alarm, - HOPR=max_alarm, + **datatype_to_epics_fields(attribute.datatype), **attribute_fields, ) case String(): - return builder.longStringIn(pv, **attribute_fields) + record = builder.longStringIn( + pv, **datatype_to_epics_fields(attribute.datatype), **attribute_fields + ) case _: raise FastCSException( f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" ) + def datatype_updater(datatype: DataType): + for name, value in datatype_to_epics_fields(datatype).items(): + record.set_field(name, value) + + attribute.add_update_datatype_callback(datatype_updater) + return record + def _create_and_link_write_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T] @@ -262,41 +284,31 @@ def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any: ) match attribute.datatype: - case Bool(znam, onam): - return builder.boolOut( + case Bool(): + record = builder.boolOut( pv, - ZNAM=znam, - ONAM=onam, + **datatype_to_epics_fields(attribute.datatype), always_update=True, on_update=on_update, ) - case Int(units, min, max, min_alarm, max_alarm): - return builder.longOut( + case Int(): + record = builder.longOut( pv, always_update=True, on_update=on_update, - EGU=units, - DRVL=min, - DRVH=max, - LOPR=min_alarm, - HOPR=max_alarm, + **datatype_to_epics_fields(attribute.datatype), **attribute_fields, ) - case Float(prec, units, min, max, min_alarm, max_alarm): - return builder.aOut( + case Float(): + record = builder.aOut( pv, always_update=True, on_update=on_update, - PREC=prec, - EGU=units, - DRVL=min, - DRVH=max, - LOPR=min_alarm, - HOPR=max_alarm, + **datatype_to_epics_fields(attribute.datatype), **attribute_fields, ) case String(): - return builder.longStringOut( + record = builder.longStringOut( pv, always_update=True, on_update=on_update, **attribute_fields ) case _: @@ -304,6 +316,13 @@ def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any: f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" ) + def datatype_updater(datatype: DataType): + for name, value in datatype_to_epics_fields(datatype).items(): + record.set_field(name, value) + + attribute.add_update_datatype_callback(datatype_updater) + return record + def _create_and_link_command_pvs(pv_prefix: str, mapping: Mapping) -> None: for single_mapping in mapping.get_controller_mappings(): diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index f0fc8a81..7ffb8157 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -12,6 +12,7 @@ AttrCallback = Callable[[T], Awaitable[None]] +@dataclass(frozen=True) # So that we can type hint with dataclass methods class DataType(Generic[T]): """Generic datatype mapping to a python type, with additional metadata.""" diff --git a/tests/backends/epics/test_ioc.py b/tests/backends/epics/test_ioc.py index 5f2a31d6..721bb295 100644 --- a/tests/backends/epics/test_ioc.py +++ b/tests/backends/epics/test_ioc.py @@ -459,3 +459,39 @@ def test_long_pv_names_discarded(mocker: MockerFixture): always_update=True, on_update=mocker.ANY, ) + + +def test_update_datatype(mocker: MockerFixture): + builder = mocker.patch("fastcs.backends.epics.ioc.builder") + + pv_name = f"{DEVICE}:Attr" + + attr_r = AttrR(Int()) + record_r = _get_input_record(pv_name, attr_r) + + builder.longIn.assert_called_once_with(pv_name, **DEFAULT_SCALAR_FIELD_ARGS) + record_r.set_field.assert_not_called() + attr_r.update_datatype(Int(units="m", min=-3)) + record_r.set_field.assert_any_call("EGU", "m") + record_r.set_field.assert_any_call("DRVL", -3) + + with pytest.raises( + ValueError, + match="Attribute datatype must be of type ", + ): + attr_r.update_datatype(String()) # type: ignore + + attr_w = AttrW(Int()) + record_w = _get_output_record(pv_name, attr_w, on_update=mocker.ANY) + + builder.longIn.assert_called_once_with(pv_name, **DEFAULT_SCALAR_FIELD_ARGS) + record_w.set_field.assert_not_called() + attr_w.update_datatype(Int(units="m", min=-3)) + record_w.set_field.assert_any_call("EGU", "m") + record_w.set_field.assert_any_call("DRVL", -3) + + with pytest.raises( + ValueError, + match="Attribute datatype must be of type ", + ): + attr_w.update_datatype(String()) # type: ignore