Skip to content

Commit

Permalink
Changes from review.
Browse files Browse the repository at this point in the history
Commands now raise an error if initiated while already running.
The command can spawn a task if it should run for a longer length of
time, or needs to be ran multiple times concurrently.
  • Loading branch information
evalott100 committed Mar 3, 2025
1 parent 8cc9ffd commit d0df0b3
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 78 deletions.
4 changes: 2 additions & 2 deletions src/fastcs/transport/epics/ca/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from fastcs.controller import Controller
from fastcs.transport.adapter import TransportAdapter
from fastcs.transport.epics.ca.ioc import EpicsIOC
from fastcs.transport.epics.ca.ioc import EpicsCAIOC
from fastcs.transport.epics.ca.options import EpicsCAOptions
from fastcs.transport.epics.docs import EpicsDocs
from fastcs.transport.epics.gui import EpicsGUI
Expand All @@ -19,7 +19,7 @@ def __init__(
self._loop = loop
self._options = options or EpicsCAOptions()
self._pv_prefix = self.options.ioc.pv_prefix
self._ioc = EpicsIOC(
self._ioc = EpicsCAIOC(
self.options.ioc.pv_prefix,
controller,
self._options.ioc,
Expand Down
2 changes: 1 addition & 1 deletion src/fastcs/transport/epics/ca/ioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
EPICS_MAX_NAME_LENGTH = 60


class EpicsIOC:
class EpicsCAIOC:
def __init__(
self,
pv_prefix: str,
Expand Down
69 changes: 41 additions & 28 deletions src/fastcs/transport/epics/pva/_pv_handlers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio
import time
from collections.abc import Callable

import numpy as np
from p4p import Value
from p4p.nt import NTEnum, NTNDArray, NTScalar, NTTable
from p4p.nt.enum import ntenum
from p4p.nt.ndarray import ntndarray
from p4p.server import ServerOperation
from p4p.server.asyncio import SharedPV

Expand All @@ -18,6 +18,7 @@
cast_to_p4p_value,
make_p4p_type,
p4p_alarm_states,
p4p_timestamp_now,
)


Expand All @@ -33,10 +34,13 @@ async def put(self, pv: SharedPV, op: ServerOperation):
[tuple(labelled_row.values()) for labelled_row in value],
dtype=self._attr_w.datatype.structured_dtype,
)
elif hasattr(value, "raw"):
raw_value = value.raw.value
else:
elif isinstance(value, Value):
raw_value = value.todict()["value"]
else:
# Unfortunately these types don't have a `todict`,
# while our `buildType` fields don't have a `.raw`.
assert isinstance(value, ntenum | ntndarray)
raw_value = value.raw.value # type:ignore

cast_value = cast_from_p4p_value(self._attr_w, raw_value)

Expand All @@ -50,45 +54,47 @@ async def put(self, pv: SharedPV, op: ServerOperation):
class CommandPvHandler:
def __init__(self, command: Callable):
self._command = command
self._task_started_event = asyncio.Event()
self._task_in_progress = False

async def _run_command(self, pv: SharedPV):
self._task_started_event.set()
self._task_started_event.clear()
async def _run_command(self) -> dict:
self._task_in_progress = True

kwargs = {}
try:
await self._command()
except Exception as e:
kwargs.update(
p4p_alarm_states(MAJOR_ALARM_SEVERITY, RECORD_ALARM_STATUS, str(e))
alarm_states = p4p_alarm_states(
MAJOR_ALARM_SEVERITY, RECORD_ALARM_STATUS, str(e)
)
else:
kwargs.update(p4p_alarm_states())
alarm_states = p4p_alarm_states()

value = NTScalar("?").wrap({"value": False, **kwargs})
timestamp = time.time()
pv.close()
pv.open(value, timestamp=timestamp)
pv.post(value, timestamp=timestamp)
self._task_in_progress = False
return alarm_states

async def put(self, pv: SharedPV, op: ServerOperation):
value = op.value()
raw_value = value.raw.value
raw_value = value["value"]

if raw_value is True:
asyncio.create_task(self._run_command(pv))
await self._task_started_event.wait()

# Flip to true once command task starts
pv.post(value, timestamp=time.time())
op.done()
if self._task_in_progress:
raise RuntimeError(
"Received request to run command but it is already in progress. "
"Maybe the command should spawn an asyncio task?"
)

# Flip to true once command task starts
pv.post({"value": True, **p4p_timestamp_now(), **p4p_alarm_states()})
op.done()
alarm_states = await self._run_command()
pv.post({"value": False, **p4p_timestamp_now(), **alarm_states})
else:
raise RuntimeError("Commands should only take the value `True`.")

Check warning on line 91 in src/fastcs/transport/epics/pva/_pv_handlers.py

View check run for this annotation

Codecov / codecov/patch

src/fastcs/transport/epics/pva/_pv_handlers.py#L91

Added line #L91 was not covered by tests


def make_shared_pv(attribute: Attribute) -> SharedPV:
initial_value = (
attribute.get()
if isinstance(attribute, AttrRW | AttrR)
if isinstance(attribute, AttrR)
else attribute.datatype.initial_value
)

Expand Down Expand Up @@ -120,10 +126,17 @@ async def on_update(value):


def make_command_pv(command: Callable) -> SharedPV:
type_ = NTScalar.buildType("?", display=True, control=True)

initial = Value(type_, {"value": False, **p4p_alarm_states()})

def _wrap(value: dict):
return Value(type_, value)

shared_pv = SharedPV(
nt=NTScalar("?"),
initial=False,
initial=initial,
handler=CommandPvHandler(command),
wrap=_wrap,
)

return shared_pv
13 changes: 5 additions & 8 deletions src/fastcs/transport/epics/pva/ioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,14 @@ def get_pv_name(pv_prefix: str, *attribute_names: str) -> str:
async def parse_attributes(
root_pv_prefix: str, controller: Controller
) -> list[StaticProvider]:
providers = []
pvi_tree = PviTree(root_pv_prefix)
provider = StaticProvider(root_pv_prefix)

for single_mapping in controller.get_controller_mappings():
path = single_mapping.controller.path
pv_prefix = get_pv_name(root_pv_prefix, *path)
provider = StaticProvider(pv_prefix)
providers.append(provider)

pvi_tree.add_block(
pvi_tree.add_sub_device(
pv_prefix,
single_mapping.controller.description,
)
Expand All @@ -56,18 +54,17 @@ async def parse_attributes(
pv_name = get_pv_name(pv_prefix, attr_name)
attribute_pv = make_shared_pv(attribute)
provider.add(pv_name, attribute_pv)
pvi_tree.add_field(pv_name, _attribute_to_access(attribute))
pvi_tree.add_signal(pv_name, _attribute_to_access(attribute))

for attr_name, method in single_mapping.command_methods.items():
pv_name = get_pv_name(pv_prefix, attr_name)
command_pv = make_command_pv(
MethodType(method.fn, single_mapping.controller)
)
provider.add(pv_name, command_pv)
pvi_tree.add_field(pv_name, "x")
pvi_tree.add_signal(pv_name, "x")

providers.append(pvi_tree.make_provider())
return providers
return [provider, pvi_tree.make_provider()]


class P4PIOC:
Expand Down
74 changes: 38 additions & 36 deletions src/fastcs/transport/epics/pva/pvi_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@dataclass
class _PviFieldInfo:
class _PviSignalInfo:
pv: str
access: AccessModeType

Expand All @@ -34,55 +34,57 @@ def _pv_to_pvi_name(pv: str) -> tuple[str, int | None]:
return _pascal_to_snake(string_without_number), number


class PviBlock(dict[str, "PviBlock"]):
class PviDevice(dict[str, "PviDevice"]):
pv_prefix: str
description: str | None
block_field_info: _PviFieldInfo | None
device_signal_info: _PviSignalInfo | None

def __init__(
self,
pv_prefix: str,
description: str | None = None,
block_field_info: _PviFieldInfo | None = None,
device_signal_info: _PviSignalInfo | None = None,
):
self.pv_prefix = pv_prefix
self.description = description
self.block_field_info = block_field_info
self.device_signal_info = device_signal_info

def __missing__(self, key: str) -> "PviBlock":
new_block = PviBlock(pv_prefix=f"{self.pv_prefix}:{key}")
self[key] = new_block
def __missing__(self, key: str) -> "PviDevice":
new_device = PviDevice(pv_prefix=f"{self.pv_prefix}:{key}")
self[key] = new_device
return self[key]

def get_recursively(self, *args: str) -> "PviBlock":
def get_recursively(self, *args: str) -> "PviDevice":
d = self
for arg in args:
d = d[arg]
return d

def _get_field_infos(self) -> dict[str, _PviFieldInfo]:
block_field_infos: dict[str, _PviFieldInfo] = {}
def _get_signal_infos(self) -> dict[str, _PviSignalInfo]:
device_signal_infos: dict[str, _PviSignalInfo] = {}

for sub_block_name, sub_block in self.items():
if sub_block:
block_field_infos[f"{sub_block_name}:PVI"] = _PviFieldInfo(
pv=f"{sub_block.pv_prefix}:PVI", access="d"
for sub_device_name, sub_device in self.items():
if sub_device:
device_signal_infos[f"{sub_device_name}:PVI"] = _PviSignalInfo(
pv=f"{sub_device.pv_prefix}:PVI", access="d"
)
if sub_block.block_field_info:
block_field_infos[sub_block_name] = sub_block.block_field_info
if sub_device.device_signal_info:
device_signal_infos[sub_device_name] = sub_device.device_signal_info

return block_field_infos
return device_signal_infos

def _make_p4p_raw_value(self) -> dict:
p4p_raw_value = defaultdict(dict)
for pv_leaf, field_info in self._get_field_infos().items():
for pv_leaf, signal_info in self._get_signal_infos().items():
pvi_name, number = _pv_to_pvi_name(pv_leaf.rstrip(":PVI") or pv_leaf)
if number is not None:
if field_info.access not in p4p_raw_value[pvi_name]:
p4p_raw_value[pvi_name][field_info.access] = {}
p4p_raw_value[pvi_name][field_info.access][f"v{number}"] = field_info.pv
if signal_info.access not in p4p_raw_value[pvi_name]:
p4p_raw_value[pvi_name][signal_info.access] = {}
p4p_raw_value[pvi_name][signal_info.access][f"v{number}"] = (
signal_info.pv
)
else:
p4p_raw_value[pvi_name][field_info.access] = field_info.pv
p4p_raw_value[pvi_name][signal_info.access] = signal_info.pv

return p4p_raw_value

Expand Down Expand Up @@ -150,42 +152,42 @@ def make_provider(
SharedPV(initial=self.make_p4p_value()),
)

for sub_block in self.values():
if sub_block:
sub_block.make_provider(provider=provider)
for sub_device in self.values():
if sub_device:
sub_device.make_provider(provider=provider)

return provider


# TODO: This can be dramatically cleaned up after https://github.com/DiamondLightSource/FastCS/issues/122
class PviTree:
def __init__(self, pv_prefix: str):
self._pvi_tree_root: PviBlock = PviBlock(pv_prefix)
self._pvi_tree_root: PviDevice = PviDevice(pv_prefix)

def add_block(
def add_sub_device(
self,
block_pv: str,
device_pv: str,
description: str | None,
):
if ":" not in block_pv:
assert block_pv == self._pvi_tree_root.pv_prefix
if ":" not in device_pv:
assert device_pv == self._pvi_tree_root.pv_prefix
self._pvi_tree_root.description = description
else:
self._pvi_tree_root.get_recursively(
*block_pv.split(":")[1:] # To remove the prefix
*device_pv.split(":")[1:] # To remove the prefix
).description = description

def add_field(
def add_signal(
self,
attribute_pv: str,
access: AccessModeType,
):
leaf_block = self._pvi_tree_root.get_recursively(*attribute_pv.split(":")[1:])
leaf_device = self._pvi_tree_root.get_recursively(*attribute_pv.split(":")[1:])

if leaf_block.block_field_info is not None:
if leaf_device.device_signal_info is not None:
raise ValueError(f"Tried to add the field '{attribute_pv}' twice.")

Check warning on line 188 in src/fastcs/transport/epics/pva/pvi_tree.py

View check run for this annotation

Codecov / codecov/patch

src/fastcs/transport/epics/pva/pvi_tree.py#L188

Added line #L188 was not covered by tests

leaf_block.block_field_info = _PviFieldInfo(pv=attribute_pv, access=access)
leaf_device.device_signal_info = _PviSignalInfo(pv=attribute_pv, access=access)

def make_provider(self) -> StaticProvider:
return self._pvi_tree_root.make_provider()
6 changes: 3 additions & 3 deletions tests/transport/epics/ca/test_softioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from fastcs.exceptions import FastCSException
from fastcs.transport.epics.ca.ioc import (
EPICS_MAX_NAME_LENGTH,
EpicsIOC,
EpicsCAIOC,
_add_attr_pvi_info,
_add_pvi_info,
_add_sub_controller_pvi_info,
Expand Down Expand Up @@ -204,7 +204,7 @@ def test_ioc(mocker: MockerFixture, controller: Controller):
"fastcs.transport.epics.ca.ioc._add_sub_controller_pvi_info"
)

EpicsIOC(DEVICE, controller)
EpicsCAIOC(DEVICE, controller)

# Check records are created
builder.boolIn.assert_called_once_with(
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_long_pv_names_discarded(mocker: MockerFixture):
long_rw_name = "attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV"
assert long_name_controller.attr_rw_short_name.enabled
assert getattr(long_name_controller, long_attr_name).enabled
EpicsIOC(DEVICE, long_name_controller)
EpicsCAIOC(DEVICE, long_name_controller)
assert long_name_controller.attr_rw_short_name.enabled
assert not getattr(long_name_controller, long_attr_name).enabled

Expand Down
Loading

0 comments on commit d0df0b3

Please sign in to comment.