Skip to content

Commit

Permalink
Feature/add-tensor-patterns-to-buffer-summary (#270)
Browse files Browse the repository at this point in the history
<img width="1258" alt="Screenshot 2024-12-04 at 12 27 42 PM"
src="https://github.com/user-attachments/assets/ad9161de-6723-4e1e-ba8b-fe8be5ec568b">
  • Loading branch information
aidemsined authored Dec 4, 2024
2 parents 2b680e5 + d531a0b commit 5073cff
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/components/GlobalSwitch.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ interface GlobalSwitchProps {
intent?: Intent;
}

// This exists so that we can properly style intent on the Switch component according to our theme
function GlobalSwitch({ label, checked, onChange, intent = Intent.PRIMARY }: GlobalSwitchProps) {
return (
<Switch
Expand Down
45 changes: 42 additions & 3 deletions src/components/buffer-summary/BufferSummaryBuffer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

import { PopoverPosition, Tooltip } from '@blueprintjs/core';
import { useState } from 'react';
import { useAtom } from 'jotai';
import { useAtom, useAtomValue } from 'jotai';
import { Buffer } from '../../model/APIData';
import { formatSize, toHex } from '../../functions/math';
import { HistoricalTensor } from '../../model/Graph';
import { getBufferColor, getTensorColor } from '../../functions/colorGenerator';
import { selectedAddressAtom, selectedTensorAtom } from '../../store/app';
import { renderMemoryLayoutAtom, selectedAddressAtom, selectedTensorAtom } from '../../store/app';
import { getDimmedColour } from '../../functions/colour';
import useBufferFocus from '../../hooks/useBufferFocus';
import { TensorMemoryLayout } from '../../functions/parseMemoryConfig';

interface BufferSummaryBufferProps {
buffer: Buffer;
Expand All @@ -26,15 +27,23 @@ function BufferSummaryBuffer({ buffer, size, position, tensor }: BufferSummaryBu
const [selectedTensor, setSelectedTensor] = useAtom(selectedTensorAtom);
const [selectedAddress, setSelectedAddress] = useAtom(selectedAddressAtom);

const showPattern = useAtomValue(renderMemoryLayoutAtom);

const { createToast, resetToasts } = useBufferFocus();

const tensorMemoryLayout = tensor?.memory_config?.memory_layout;
const originalColour = tensor ? getTensorColor(tensor.id) : getBufferColor(buffer.address);
const dimmedColour = originalColour ? getDimmedColour(originalColour) : '#000';
const currentColour = (selectedTensor && selectedTensor !== tensor?.id ? dimmedColour : originalColour) ?? '#000';

const styleProps = {
width: `${size}%`,
left: `${position}%`,
backgroundColor: selectedTensor && selectedTensor !== tensor?.id ? dimmedColour : originalColour,
...(showPattern && tensorMemoryLayout
? getBackgroundPattern(tensorMemoryLayout, currentColour)
: {
backgroundColor: currentColour,
}),
};

const clearFocusedBuffer = () => {
Expand Down Expand Up @@ -88,4 +97,34 @@ function BufferSummaryBuffer({ buffer, size, position, tensor }: BufferSummaryBu
);
}

const FG_COLOUR = 'rgba(0, 0, 0, 0.7)';

function getBackgroundPattern(
layout: TensorMemoryLayout,
colour: string,
): { backgroundImage?: string; backgroundSize?: string } {
let pattern = {};

if (layout === TensorMemoryLayout.INTERLEAVED) {
pattern = {
backgroundImage: `radial-gradient(${FG_COLOUR} 0.8px, ${colour} 0.8px)`,
backgroundSize: '4px 4px',
};
}
if (layout === TensorMemoryLayout.BLOCK_SHARDED) {
pattern = {
backgroundImage: `linear-gradient(${FG_COLOUR} 0.4px, transparent 0.4px), linear-gradient(to right, ${FG_COLOUR} 0.4px, ${colour} 0.4px)`,
backgroundSize: '7px 7px',
};
}
if (layout === TensorMemoryLayout.HEIGHT_SHARDED) {
pattern = {
backgroundSize: '6px',
backgroundImage: `repeating-linear-gradient(to right, ${FG_COLOUR}, ${FG_COLOUR} 0.4px, ${colour} 0.4px, ${colour})`,
};
}

return pattern;
}

export default BufferSummaryBuffer;
11 changes: 10 additions & 1 deletion src/components/buffer-summary/BufferSummaryPlotRenderer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import 'styles/components/BufferSummaryPlot.scss';
import ROUTES from '../../definitions/routes';
import isValidNumber from '../../functions/isValidNumber';
import { HistoricalTensorsByOperation } from '../../model/BufferSummary';
import { selectedDeviceAtom, showHexAtom } from '../../store/app';
import { renderMemoryLayoutAtom, selectedDeviceAtom, showHexAtom } from '../../store/app';
import GlobalSwitch from '../GlobalSwitch';

const PLACEHOLDER_ARRAY_SIZE = 30;
Expand All @@ -35,6 +35,7 @@ function BufferSummaryPlotRenderer({ buffersByOperation, tensorListByOperation }
const [hasScrolledToBottom, setHasScrolledToBottom] = useState(false);
const [showHex, setShowHex] = useAtom(showHexAtom);
const deviceId = useAtomValue(selectedDeviceAtom) || 0;
const [renderMemoryLayout, setRenderMemoryLayout] = useAtom(renderMemoryLayoutAtom);
const [isZoomedIn, setIsZoomedIn] = useState(false);
const { data: devices, isLoading: isLoadingDevices } = useDevices();
const scrollElementRef = useRef(null);
Expand Down Expand Up @@ -104,6 +105,14 @@ function BufferSummaryPlotRenderer({ buffersByOperation, tensorListByOperation }
setShowHex(!showHex);
}}
/>

<GlobalSwitch
label='Tensor memory layout overlay'
checked={renderMemoryLayout}
onChange={() => {
setRenderMemoryLayout(!renderMemoryLayout);
}}
/>
</div>

<p className='x-axis-label'>Memory Address</p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ const OperationDetailsComponent: React.FC<OperationDetailsProps> = ({ operationI
setShowCircularBuffer(!showCircularBuffer);
}}
/>
<Switch
<GlobalSwitch
label='Tensor memory layout overlay'
checked={renderMemoryLayoutPattern}
onChange={() => {
Expand Down
7 changes: 5 additions & 2 deletions src/hooks/useAPI.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ export interface DeviceData {
/** @description
* this is a temporary method to fetch all buffers for all operations. it may not be used in the future
*/
const fetchAllBuffers = async (bufferType: BufferType | null, deviceId?: number): Promise<BuffersByOperationData[]> => {
const fetchAllBuffers = async (
bufferType: BufferType | null,
deviceId: number | null,
): Promise<BuffersByOperationData[]> => {
const params = {
buffer_type: bufferType,
device_id: deviceId,
Expand Down Expand Up @@ -286,7 +289,7 @@ export const useNextBuffer = (address: number | null, consumers: number[], query
});
};

export const useBuffers = (bufferType: BufferType, deviceId?: number) => {
export const useBuffers = (bufferType: BufferType, deviceId: number | null) => {
return useQuery({
queryFn: () => fetchAllBuffers(bufferType, deviceId),
queryKey: ['fetch-all-buffers', bufferType, deviceId],
Expand Down

0 comments on commit 5073cff

Please sign in to comment.