diff --git a/playground/SingleTestCase.tsx b/playground/SingleTestCase.tsx index 9fada20b..28803390 100644 --- a/playground/SingleTestCase.tsx +++ b/playground/SingleTestCase.tsx @@ -1,7 +1,7 @@ import { MermaidDiagram } from "./MermaidDiagram"; export interface TestCase { - type: "class" | "flowchart" | "sequence" | "unsupported"; + type: "class" | "flowchart" | "sequence" | "unsupported" | "state"; name: string; definition: string; } diff --git a/playground/Testcases.tsx b/playground/Testcases.tsx index 0db591d4..d3e06e2c 100644 --- a/playground/Testcases.tsx +++ b/playground/Testcases.tsx @@ -2,6 +2,7 @@ import { Fragment } from "react"; import { FLOWCHART_DIAGRAM_TESTCASES } from "./testcases/flowchart"; import { SEQUENCE_DIAGRAM_TESTCASES } from "./testcases/sequence.ts"; +import { STATE_DIAGRAM_TESTCASES } from "./testcases/state.ts"; import { CLASS_DIAGRAM_TESTCASES } from "./testcases/class.ts"; import { UNSUPPORTED_DIAGRAM_TESTCASES } from "./testcases/unsupported.ts"; @@ -18,6 +19,7 @@ interface TestcasesProps { const Testcases = ({ onChange }: TestcasesProps) => { const testcaseTypes: { name: string; testcases: TestCase[] }[] = [ + { name: "State", testcases: STATE_DIAGRAM_TESTCASES }, { name: "Flowchart", testcases: FLOWCHART_DIAGRAM_TESTCASES }, { name: "Sequence", testcases: SEQUENCE_DIAGRAM_TESTCASES }, { name: "Class", testcases: CLASS_DIAGRAM_TESTCASES }, diff --git a/playground/testcases/state.ts b/playground/testcases/state.ts new file mode 100644 index 00000000..a63ed1a7 --- /dev/null +++ b/playground/testcases/state.ts @@ -0,0 +1,284 @@ +import type { TestCase } from "../SingleTestCase"; + +export const STATE_DIAGRAM_TESTCASES: TestCase[] = [ + { + name: "Declare a state with just an id", + definition: `stateDiagram-v2 + Stateid +`, + type: "state", + }, + { + name: "Declare a state with description", + definition: `stateDiagram-v2 + state "This is a state description" as s2 +`, + type: "state", + }, + { + name: "Declare a state with description using another syntax", + definition: `stateDiagram-v2 + s2 : This is a state description +`, + type: "state", + }, + { + name: "Simple Transition", + definition: `stateDiagram-v2 + s1 --> s2 +`, + type: "state", + }, + { + name: "Simple transition with a text", + definition: `stateDiagram-v2 + s1 --> s2: Transition Description +`, + type: "state", + }, + { + name: "Start and End ", + definition: `stateDiagram-v2 + [*] --> s1 + s1 --> [*] +`, + type: "state", + }, + { + name: "Composite states ", + definition: `stateDiagram-v2 + [*] --> First + state First { + [*] --> second + second --> [*] + } + `, + type: "state", + }, + { + name: "Composite states in deep layers", + definition: `stateDiagram-v2 + [*] --> First + + state First { + [*] --> Second + + state Second { + [*] --> second + second --> Third + + state Third { + [*] --> third + third --> [*] + } + } + } +`, + type: "state", + }, + { + name: "Transition between composite states", + definition: `stateDiagram-v2 + [*] --> First + First --> Second + First --> Third + + state First { + [*] --> fir + fir --> [*] + } + state Second { + [*] --> sec + sec --> [*] + } + state Third { + [*] --> thi + thi --> [*] + } +`, + type: "state", + }, + { + name: "Edge case when two composite states are connected by the same state", + definition: `stateDiagram-v2 + [*] --> First + state First { + [*] --> second + second --> [*] + } + state Second { + [*] --> second + second --> [*] + } +`, + type: "state", + }, + { + name: "Choice", + definition: `stateDiagram-v2 + state if_state <> + [*] --> IsPositive + IsPositive --> if_state + if_state --> False: if n < 0 + if_state --> True : if n >= 0 +`, + type: "state", + }, + { + name: "Choice in a composite state", + definition: `stateDiagram-v2 + [*] --> First + First --> Second + First --> Third + + state First { + [*] --> fir + fir --> [*] + + state if_state <> + [*] --> IsPositive + IsPositive --> if_state + if_state --> False: if n < 0 + if_state --> True : if n >= 0 + } + state Second { + [*] --> sec + sec --> [*] + } + state Third { + [*] --> thi + thi --> [*] + } +`, + type: "state", + }, + + { + name: "Forks", + definition: `stateDiagram-v2 + state fork_state <> + fork_state --> State2 + fork_state --> State3 + [*] --> fork_state + + state join_state <> + State2 --> join_state + State3 --> join_state + join_state --> State4 + State4 --> [*] +`, + type: "state", + }, + { + name: "Notes", + definition: `stateDiagram-v2 + State1: The state with a note + note right of State1 + Important information! You can write + notes. + end note + State1 --> State2 + note left of State2 : This is the note to the left. +`, + type: "state", + }, + { + name: "Note left and right in same state", + definition: `stateDiagram-v2 + State1: The state with a note + note right of State1 + Important information! You can write + notes. + end note + note left of State1 : This is the note to the left. +`, + type: "state", + }, + { + name: "Multiple notes in the same state", + definition: `stateDiagram-v2 + State1: The state with a note + note right of State1 + Important information! You can write + notes. + end note + note left of State1 : Left. + note left of State1 : Join. + note left of State1 : Out. + note right of State1 : Ok!. +`, + type: "state", + }, + { + name: "Notes inside a composite state", + definition: `stateDiagram-v2 + [*] --> First + state First { + [*] --> second + second --> [*] + + note right of second + First is a composite state + end note + + } +`, + type: "state", + }, + { + name: "Concurrency", + definition: `stateDiagram-v2 + [*] --> Active + + state Active { + [*] --> NumLockOff + NumLockOff --> NumLockOn : EvNumLockPressed + NumLockOn --> NumLockOff : EvNumLockPressed + -- + [*] --> CapsLockOff + CapsLockOff --> CapsLockOn : EvCapsLockPressed + CapsLockOn --> CapsLockOff : EvCapsLockPressed + -- + [*] --> ScrollLockOff + ScrollLockOff --> ScrollLockOn : EvScrollLockPressed + ScrollLockOn --> ScrollLockOff : EvScrollLockPressed + } + `, + type: "state", + }, + { + name: "Styling", + definition: `stateDiagram-v2 + direction TB + + classDef notMoving fill:white + classDef movement font-style:italic + classDef badBadEvent fill:#f00,color:white,font-weight:bold,stroke-width:2px,stroke:yellow + + [*]--> Still + Still --> [*] + Still --> Moving + Moving --> Still + Moving --> Crash + Crash --> [*] + + class Still notMoving + class Moving, Crash movement + class Crash badBadEvent + class end badBadEvent +`, + type: "state", + }, + { + name: "Sample 1", + definition: `stateDiagram-v2 + [*] --> Still + Still --> [*] + Still --> Moving + Moving --> Still + Moving --> Crash + Crash --> [*] + `, + type: "state", + }, +]; diff --git a/src/converter/types/state.ts b/src/converter/types/state.ts new file mode 100644 index 00000000..c7bf3df8 --- /dev/null +++ b/src/converter/types/state.ts @@ -0,0 +1,118 @@ +import type { ExcalidrawElementSkeleton } from "@excalidraw/excalidraw/types/data/transform.js"; + +import { GraphConverter } from "../GraphConverter.js"; +import { + transformToExcalidrawContainerSkeleton, + transformToExcalidrawArrowSkeleton, + transformToExcalidrawTextSkeleton, + transformToExcalidrawLineSkeleton, +} from "../transformToExcalidrawSkeleton.js"; + +import type { State } from "../../parser/state.js"; +import type { Arrow } from "../../elementSkeleton.js"; +import type { Point } from "@excalidraw/excalidraw/types/types.js"; + +export const StateToExcalidrawSkeletonConvertor = new GraphConverter({ + converter: (chart: State) => { + const elements: ExcalidrawElementSkeleton[] = []; + + chart.nodes.forEach((node) => { + const groupIds = node.groupId?.split(", "); + + switch (node.type) { + case "ellipse": + case "diamond": + case "rectangle": + const element = transformToExcalidrawContainerSkeleton(node); + + if (node?.subtype !== "note") { + Object.assign(element, { + roundness: { type: 3 }, + }); + } + + if (node.groupId) { + Object.assign(element, { + label: { + ...element.label, + groupIds, + }, + groupIds, + }); + } + + elements.push(element); + break; + case "line": + const line = transformToExcalidrawLineSkeleton(node); + + if (node.groupId) { + Object.assign(line, { + groupIds, + }); + } + + elements.push(line); + break; + case "text": + const text = transformToExcalidrawTextSkeleton(node); + elements.push(text); + break; + default: + throw `unknown type ${node.type}`; + } + }); + + chart.edges.forEach((edge) => { + if (!edge) { + return; + } + + const points = edge.reflectionPoints.map((point: Point) => [ + point.x - edge.reflectionPoints[0].x, + point.y - edge.reflectionPoints[0].y, + ]); + + const arrow: Arrow = { + ...edge, + endArrowhead: "triangle", + points, + }; + + const startVertex = elements.find((e) => e.id === edge.start); + const endVertex = elements.find((e) => e.id === edge.end); + + if (!startVertex || !endVertex) { + return; + } + + const groupIds = edge.groupId?.split(", "); + + if (endVertex.id?.includes("note") || startVertex.id?.includes("note")) { + arrow.endArrowhead = null; + arrow.strokeStyle = "dashed"; + } + + arrow.start = { + id: startVertex.id, + type: "rectangle", + }; + arrow.end = { + id: endVertex.id, + type: "rectangle", + }; + + const arrowSkeleton = transformToExcalidrawArrowSkeleton(arrow); + + if (groupIds) { + Object.assign(arrowSkeleton, { + groupIds, + }); + } + + elements.push(arrowSkeleton); + }); + + return { elements }; + }, +}); diff --git a/src/elementSkeleton.ts b/src/elementSkeleton.ts index 5a40ede9..1ed7abf2 100644 --- a/src/elementSkeleton.ts +++ b/src/elementSkeleton.ts @@ -47,7 +47,7 @@ export type Text = { }; export type Container = { - type: "rectangle" | "ellipse"; + type: "rectangle" | "ellipse" | "diamond"; x: number; y: number; id?: string; @@ -227,7 +227,7 @@ export const createTextSkeletonFromSVG = ( }; export const createContainerSkeletonFromSVG = ( - node: SVGSVGElement | SVGRectElement, + node: SVGSVGElement | SVGRectElement | SVGCircleElement, type: Container["type"], opts: { id?: string; diff --git a/src/graphToExcalidraw.ts b/src/graphToExcalidraw.ts index 52a0675f..deb11b41 100644 --- a/src/graphToExcalidraw.ts +++ b/src/graphToExcalidraw.ts @@ -6,10 +6,12 @@ import { SequenceToExcalidrawSkeletonConvertor } from "./converter/types/sequenc import { Sequence } from "./parser/sequence.js"; import { Flowchart } from "./parser/flowchart.js"; import { Class } from "./parser/class.js"; +import { State } from "./parser/state.js"; import { classToExcalidrawSkeletonConvertor } from "./converter/types/class.js"; +import { StateToExcalidrawSkeletonConvertor } from "./converter/types/state.js"; export const graphToExcalidraw = ( - graph: Flowchart | GraphImage | Sequence | Class, + graph: Flowchart | GraphImage | Sequence | Class | State, options: MermaidOptions = {} ): MermaidToExcalidrawResult => { switch (graph.type) { @@ -25,6 +27,10 @@ export const graphToExcalidraw = ( return SequenceToExcalidrawSkeletonConvertor.convert(graph, options); } + case "state": { + return StateToExcalidrawSkeletonConvertor.convert(graph, options); + } + case "class": { return classToExcalidrawSkeletonConvertor.convert(graph, options); } diff --git a/src/parseMermaid.ts b/src/parseMermaid.ts index 5de62f23..fb7c112c 100644 --- a/src/parseMermaid.ts +++ b/src/parseMermaid.ts @@ -1,10 +1,11 @@ import mermaid from "mermaid"; import { GraphImage } from "./interfaces.js"; -import { DEFAULT_FONT_SIZE, MERMAID_CONFIG } from "./constants.js"; +import { MERMAID_CONFIG } from "./constants.js"; import { encodeEntities } from "./utils.js"; import { Flowchart, parseMermaidFlowChartDiagram } from "./parser/flowchart.js"; import { Sequence, parseMermaidSequenceDiagram } from "./parser/sequence.js"; import { Class, parseMermaidClassDiagram } from "./parser/class.js"; +import { State, parseMermaidStateDiagram } from "./parser/state.js"; // Fallback to Svg const convertSvgToGraphImage = (svgContainer: HTMLDivElement) => { @@ -44,7 +45,7 @@ const convertSvgToGraphImage = (svgContainer: HTMLDivElement) => { export const parseMermaid = async ( definition: string -): Promise => { +): Promise => { mermaid.initialize(MERMAID_CONFIG); // Parse the diagram @@ -77,6 +78,12 @@ export const parseMermaid = async ( break; } + case "stateDiagram": { + data = parseMermaidStateDiagram(diagram, svgContainer); + + break; + } + case "classDiagram": { data = parseMermaidClassDiagram(diagram, svgContainer); break; diff --git a/src/parser/flowchart.ts b/src/parser/flowchart.ts index cae51d7c..58ef19c0 100644 --- a/src/parser/flowchart.ts +++ b/src/parser/flowchart.ts @@ -1,7 +1,7 @@ import { computeEdgePositions, + computeElementPosition, entityCodesToText, - getTransformAttr, } from "../utils.js"; import { CONTAINER_STYLE_PROPERTY, @@ -169,53 +169,6 @@ const parseEdge = ( }; }; -// Compute element position -const computeElementPosition = ( - el: Element | null, - containerEl: Element -): Position => { - if (!el) { - throw new Error("Element not found"); - } - - let root = el.parentElement?.parentElement; - - const childElement = el.childNodes[0] as SVGSVGElement; - let childPosition = { x: 0, y: 0 }; - if (childElement) { - const { transformX, transformY } = getTransformAttr(childElement); - - const boundingBox = childElement.getBBox(); - childPosition = { - x: - Number(childElement.getAttribute("x")) || - transformX + boundingBox.x || - 0, - y: - Number(childElement.getAttribute("y")) || - transformY + boundingBox.y || - 0, - }; - } - - const { transformX, transformY } = getTransformAttr(el); - const position = { - x: transformX + childPosition.x, - y: transformY + childPosition.y, - }; - while (root && root.id !== containerEl.id) { - if (root.classList.value === "root" && root.hasAttribute("transform")) { - const { transformX, transformY } = getTransformAttr(root); - position.x += transformX; - position.y += transformY; - } - - root = root.parentElement; - } - - return position; -}; - export const parseMermaidFlowChartDiagram = ( diagram: Diagram, containerEl: Element diff --git a/src/parser/state.ts b/src/parser/state.ts new file mode 100644 index 00000000..25c77ef0 --- /dev/null +++ b/src/parser/state.ts @@ -0,0 +1,668 @@ +import type { Diagram } from "mermaid/dist/Diagram.js"; +import { + type Container, + type Line, + type Node, + createContainerSkeletonFromSVG, +} from "../elementSkeleton.js"; +import { computeEdgePositions, computeElementPosition } from "../utils.js"; +import type { Edge } from "./flowchart.js"; + +export interface State { + type: "state"; + nodes: any[]; + edges: any[]; +} + +// the names are taken from mermaidParser.lineType +export enum LineType { + DOTTED_LINE = 1, + LINE = 0, +} + +// the names are taken from mermaidParser.relationType +export enum RelationType { + AGGREGATION = 0, + COMPOSITION = 2, + DEPENDENCY = 3, + EXTENSION = 1, +} + +export interface StateNode { + width: number; + height: number; +} + +export interface RelationState { + description?: string; + [state: `state${number}`]: { + description: string; + id: string; + start?: boolean; + stmt: "state"; + type: "default"; + }; + stmt: "relation"; +} + +export interface CompositeState { + doc: Array; + description: string; + id: string; + type: "default"; + stmt: "state"; +} + +export interface SingleState { + description: string; + id: string; + type: "default"; + stmt: "state"; +} + +export interface NoteState { + id: string; + note: { + position: string; + text: string; + }; + stmt: "state"; +} + +export interface SpecialState { + id: string; + type: "choice" | "fork" | "join"; + stmt: "state"; +} + +export interface ConcurrencyState { + id: string; + doc: ParsedDoc[]; + stmt: "state"; + type: "divider"; +} + +export type ParsedDoc = + | NoteState + | ConcurrencyState + | SpecialState + | SingleState + | CompositeState + | RelationState; + +const MARGIN_TOP_LINE_X_AXIS = 25; +const DEFAULT_FILL_COLOR = "rgb(236, 236, 255)"; +const DEFAULT_STROKE_COLOR = "rgb(147, 112, 219)"; + +const isNoteState = (node: ParsedDoc): node is NoteState => { + return "note" in node; +}; + +const isCompositeState = (node: ParsedDoc): node is CompositeState => { + return "doc" in node && node.type === "default"; +}; + +const isConcurrencyState = (node: ParsedDoc): node is ConcurrencyState => { + return "doc" in node && node.type === "divider"; +}; + +const isSingleState = (node: ParsedDoc): node is SingleState => { + return "doc" in node === false && "type" in node && node.type === "default"; +}; + +const isSpecialState = (node: ParsedDoc): node is SpecialState => { + return "doc" in node === false && "type" in node && node.type !== "default"; +}; + +const isRelationState = (node: ParsedDoc): node is RelationState => { + return node.stmt === "relation"; +}; + +const createInnerEllipseExcalidrawElement = ( + element: SVGSVGElement, + position: { x: number; y: number; width: number; height: number }, + size = 4 +) => { + const innerEllipse = createContainerSkeletonFromSVG(element, "ellipse", { + id: `${element.id}-inner`, + groupId: element.id, + }); + + innerEllipse.width = size; + innerEllipse.height = size; + + innerEllipse.x = position.x + position.width / 2 - innerEllipse.width / 2; + innerEllipse.y = position.y + position.height / 2 - innerEllipse.height / 2; + + innerEllipse.strokeColor = "black"; + innerEllipse.bgColor = "black"; + + return innerEllipse; +}; + +const createExcalidrawElement = ( + node: SVGSVGElement, + containerEl: Element, + shape: "ellipse" | "rectangle", + additionalProps: Parameters[2] +) => { + const nodePosition = computeElementPosition(node, containerEl); + + const nodeElement = createContainerSkeletonFromSVG(node, shape, { + id: node.id, + ...additionalProps, + }); + + nodeElement.x = nodePosition.x; + nodeElement.y = nodePosition.y; + + return nodeElement; +}; + +const createClusterExcalidrawElement = ( + clusterNode: SVGSVGElement, + containerEl: Element, + state: Extract +) => { + const clusterElementPosition = computeElementPosition( + clusterNode, + containerEl + ); + + const clusterElementSkeleton = createContainerSkeletonFromSVG( + clusterNode, + "rectangle", + { + id: state.id, + label: { text: state.description || state.id, verticalAlign: "top" }, + groupId: state.id, + } + ); + + clusterElementSkeleton.x = clusterElementPosition.x; + clusterElementSkeleton.y = clusterElementPosition.y; + + const topLine: Line = { + type: "line", + startX: clusterElementPosition.x, + startY: clusterElementPosition.y + MARGIN_TOP_LINE_X_AXIS, + endX: clusterElementPosition.x + (clusterElementSkeleton.width || 0), + endY: clusterElementPosition.y + MARGIN_TOP_LINE_X_AXIS, + strokeColor: "black", + strokeWidth: 1, + groupId: state.id, + }; + + return { clusterElementSkeleton, topLine }; +}; + +const getClusterElement = (containerEl: Element, id: string) => { + return containerEl.querySelector(`[id="${id}"]`)!; +}; + +const getRelationElement = (containerEl: Element, id: string) => { + return containerEl.querySelector(`[data-id="${id}"]`)!; +}; + +const computeSpecialType = (specialType: SpecialState): Partial => { + switch (specialType.type) { + case "choice": + return { type: "diamond", label: undefined }; + default: + return { type: "rectangle", label: undefined, bgColor: "#000" }; + } +}; + +const createRelationExcalidrawElement = ( + relation: RelationState["state1" | "state2"], + relationNode: SVGSVGElement, + containerEl: Element, + specialTypes: Record, + groupId?: string +): [Container, Container | null] => { + const shape = relation?.start !== undefined ? "ellipse" : "rectangle"; + const styles = getComputedStyle(relationNode.firstElementChild!); + const haveCustomStyles = relationNode.classList.length > 3; + const label = + relation?.start === undefined + ? { + label: { + text: relation.description || relation.id, + }, + } + : undefined; + + const relationContainer = createExcalidrawElement( + relationNode, + containerEl, + shape, + label + ); + + relationContainer.groupId = groupId; + if (relationContainer.label) { + relationContainer.label.color = styles.color; + } + + if (typeof relation.start === "undefined" && haveCustomStyles) { + relationContainer.strokeColor = + styles.stroke === DEFAULT_STROKE_COLOR ? "#000" : styles.stroke; + relationContainer.bgColor = + styles.fill === DEFAULT_FILL_COLOR ? undefined : styles.fill; + } + + if (relation?.start) { + relationContainer.bgColor = "#000"; + } + + let innerEllipse = null; + + if (relation?.start === false) { + innerEllipse = createInnerEllipseExcalidrawElement(relationNode, { + x: relationContainer.x, + y: relationContainer.y, + width: relationContainer.width!, + height: relationContainer.height!, + }); + + relationContainer.groupId = groupId + ? `${relationContainer.id}, ${groupId}` + : relationContainer.id; + innerEllipse.groupId += groupId ? `, ${groupId}` : ""; + } + + if (specialTypes[relation.id]) { + Object.assign( + relationContainer, + computeSpecialType(specialTypes[relation.id]) + ); + } + + return [relationContainer, innerEllipse]; +}; + +const parseRelation = ( + relation: Extract, + containerEl: Element, + nodes: Array, + processedNodeRelations: Set, + specialTypes: Record, + groupId?: string +) => { + const relationStart = relation.state1; + const relationEnd = relation.state2; + + const relationStartNode = getRelationElement(containerEl, relationStart.id); + const relationEndNode = getRelationElement(containerEl, relationEnd.id); + + // If the relations is not found, is a cluster relation and we don't need to create a node for it + if (!relationStartNode && !relationEndNode) { + return; + } + + if (relationStartNode && !processedNodeRelations.has(relationStart.id)) { + const [relationStartContainer] = createRelationExcalidrawElement( + relationStart, + relationStartNode, + containerEl, + specialTypes, + groupId + ); + + nodes.push(relationStartContainer); + processedNodeRelations.add(relationStart.id); + } + + if (relationEndNode && !processedNodeRelations.has(relationEnd.id)) { + const [relationEndContainer, innerEllipse] = + createRelationExcalidrawElement( + relationEnd, + relationEndNode, + containerEl, + specialTypes, + groupId + ); + + nodes.push(relationEndContainer); + processedNodeRelations.add(relationEnd.id); + + if (innerEllipse) { + nodes.push(innerEllipse); + } + } +}; + +const parseDoc = ( + doc: ParsedDoc[], + containerEl: Element, + nodes: Array = [], + processedNodeRelations: Set = new Set(), + specialTypes: Record = {}, + groupId?: string +) => { + doc.forEach((state) => { + if (isSingleState(state)) { + const singleStateNode = containerEl.querySelector( + `[data-id="${state.id}"]` + )!; + + const stateElement = createExcalidrawElement( + singleStateNode, + containerEl, + "rectangle", + { label: { text: state.description || state.id } } + ); + + stateElement.groupId += groupId ? `, ${groupId}` : ""; + processedNodeRelations.add(state.id); + nodes.push(stateElement); + return; + } + + if (isSpecialState(state)) { + specialTypes[state.id] = state; + return; + } + + if (isRelationState(state)) { + parseRelation( + state, + containerEl, + nodes, + processedNodeRelations, + specialTypes, + groupId + ); + + return; + } + + if (isConcurrencyState(state)) { + const dividerNode = containerEl.querySelector( + `[id*="${state.id}"]` + )!; + + const dividerElement = createExcalidrawElement( + dividerNode, + containerEl, + "rectangle", + { + id: dividerNode.id, + groupId: dividerNode.id, + subtype: "note", + } + ); + + groupId = `${dividerNode.id}${groupId ? `, ${groupId}` : ""}`; + dividerElement.bgColor = "#e9ecef"; + dividerElement.groupId = groupId; + + nodes.push(dividerElement); + + parseDoc( + state.doc, + containerEl, + nodes, + processedNodeRelations, + specialTypes, + groupId + ); + + groupId = undefined; + return; + } + + if (isCompositeState(state)) { + const clusterElement = getClusterElement(containerEl, state.id); + + const { clusterElementSkeleton, topLine } = + createClusterExcalidrawElement(clusterElement, containerEl, state); + + groupId = `${state.id}${groupId ? `, ${groupId}` : ""}`; + clusterElementSkeleton.groupId = groupId; + topLine.groupId = groupId; + nodes.push(clusterElementSkeleton); + nodes.push(topLine); + + parseDoc( + state.doc, + containerEl, + nodes, + processedNodeRelations, + specialTypes, + groupId + ); + + groupId = undefined; + } + }); + + return nodes; +}; + +const parseEdges = (nodes: ParsedDoc[], containerEl: Element): any[] => { + let rootEdgeIndex = 0; + + function parse( + nodes: ParsedDoc[], + retrieveEdgeFromClusterSvg = false, + clusterId?: string + ): any[] { + return nodes + .filter((node) => { + return ( + isCompositeState(node) || + isRelationState(node) || + isConcurrencyState(node) || + isNoteState(node) + ); + }) + .flatMap((node, index) => { + if (isCompositeState(node) || isConcurrencyState(node)) { + const clusters = getClusterElement(containerEl, node.id)?.closest( + ".root" + ); + + const clusterHasOwnEdges = clusters?.hasAttribute("transform"); + clusterId = `${node.id}${clusterId ? `, ${clusterId}` : ""}`; + + const edges = parse(node.doc, clusterHasOwnEdges, clusterId); + + clusterId = undefined; + + return edges; + } else if (node.stmt === "relation") { + const startId = node.state1.id; + const endId = node.state2.id; + + // If the relations node not found, is a relation with a cluster. + const nodeStartElement = + getRelationElement(containerEl, startId) || + getClusterElement(containerEl, startId); + const nodeEndElement = + getRelationElement(containerEl, endId) || + getClusterElement(containerEl, endId); + + const rootContainer = nodeStartElement.closest(".root"); + + if (!rootContainer) { + throw new Error("Root container when parsing edge not found"); + } + + const edges = retrieveEdgeFromClusterSvg + ? rootContainer.querySelector(".edgePaths")?.children + : containerEl.querySelector(".edgePaths")?.children; + + if (!edges) { + throw new Error("Edges not found"); + } + + const edgeStartElement = edges[ + retrieveEdgeFromClusterSvg ? index : rootEdgeIndex + ] as SVGPathElement; + + const position = computeElementPosition( + edgeStartElement, + containerEl + ); + const edgePositionData = computeEdgePositions( + edgeStartElement, + position, + "MCL" + ); + /** + * Edge case where cluster don't have the .edgePaths in SVG, + * so we need to increment the index manually and get from the root container svg + * */ + rootEdgeIndex++; + + return { + start: nodeStartElement.id, + end: nodeEndElement.id, + groupId: clusterId, + label: { + text: node?.description, + }, + ...edgePositionData, + }; + } + + if (isNoteState(node)) { + rootEdgeIndex++; + } + + return []; + }); + } + + return parse(nodes); +}; + +const parseNotes = (doc: ParsedDoc[], containerEl: Element) => { + let rootIndex = 0; + const noteIndex: Record = {}; + const notes: Container[] = []; + const edges: Partial[] = []; + + const processNote = (state: NoteState): [Container, Partial] => { + if (!noteIndex[state.id]) { + noteIndex[state.id] = 0; + } + const noteNodes = Array.from( + containerEl.querySelectorAll( + `[data-id*="${state.id}----note"]` + ) + ); + + const noteNode = noteNodes[noteIndex[state.id]]; + + const noteElement = createExcalidrawElement( + noteNode, + containerEl, + "rectangle", + { + label: { text: state.note.text }, + id: noteNode.id, + subtype: "note", + } + ); + + const rootContainer = noteNode.closest(".root")!; + + const edge = rootContainer.querySelector(".edgePaths")?.children[ + rootIndex + ] as SVGPathElement; + + const position = computeElementPosition(edge, containerEl); + + const edgePositionData = computeEdgePositions(edge, position, "MCL"); + + let startNode = rootContainer.querySelector( + `[data-id*="${state.id}"]:not([data-id*="note"])` + )!; + + const isClusterStartRelation = + startNode.id.includes(`_start`) || startNode.id.includes(`_end`); + + if (isClusterStartRelation) { + startNode = getClusterElement(containerEl, state.id); + } + + const edgeElement: Partial = { + start: startNode.id, + end: noteNode.id, + ...edgePositionData, + }; + + if (state.note.position.includes("left")) { + edgeElement.end = startNode.id; + edgeElement.start = noteNode.id; + } + + noteIndex[state.id]++; + return [noteElement, edgeElement]; + }; + + doc + .filter( + (state) => + isNoteState(state) || isCompositeState(state) || isRelationState(state) + ) + .flatMap((state) => { + if (isNoteState(state)) { + const [noteElement, edgeElement] = processNote(state); + + rootIndex++; + notes.push(noteElement); + edges.push(edgeElement); + } + + if (isCompositeState(state)) { + const { notes: compositeNotes, edges: compositeEdges } = parseNotes( + state.doc, + containerEl + ); + notes.push(...compositeNotes); + edges.push(...compositeEdges); + } + + if (isRelationState(state)) { + rootIndex++; + } + }); + + return { notes, edges }; +}; + +export const parseMermaidStateDiagram = ( + diagram: Diagram, + containerEl: Element +): State => { + // Get mermaid parsed data from parser shared variable `yy` + //@ts-ignore + const mermaidParser = diagram.parser.yy; + // const nodes: Array = []; + const rootDocV2 = mermaidParser.getRootDocV2(); + + console.debug({ + document: rootDocV2, + mermaidParser, + clusters: Array.from(containerEl.querySelectorAll(".clusters")).map( + (el) => el.childNodes + ), + states: mermaidParser.getStates(), + relations: mermaidParser.getRelations(), + classes: mermaidParser.getClasses(), + logDocuments: mermaidParser.logDocuments(), + }); + + const nodes = parseDoc(rootDocV2.doc, containerEl); + const edges = parseEdges(rootDocV2.doc, containerEl); + + const { notes, edges: edgeNotes } = parseNotes(rootDocV2.doc, containerEl); + + nodes.push(...notes); + edges.push(...edgeNotes); + + return { type: "state", nodes, edges }; +}; diff --git a/src/utils.ts b/src/utils.ts index 4cb1caf4..add87eda 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -70,7 +70,8 @@ interface EdgePositionData { // Compute edge postion start, end and points (reflection points) export const computeEdgePositions = ( pathElement: SVGPathElement, - offset: Position = { x: 0, y: 0 } + offset: Position = { x: 0, y: 0 }, + commandsPattern = "LM" ): EdgePositionData => { // Check if the element is a path else throw an error if (pathElement.tagName.toLowerCase() !== "path") { @@ -85,9 +86,9 @@ export const computeEdgePositions = ( throw new Error('Path element does not contain a "d" attribute'); } - // Split the d attribute based on M (Move To) and L (Line To) commands - // eg "M29.383,38.5L29.383,63.5L29.383,83.2" => ["M29.383,38.5", "L29.383,63.5", "L29.383,83.2"] - const commands = dAttr.split(/(?=[LM])/); + // Split the d attribute based on some commands: M (Move To), L (Line To) commands and if specifies C (Curve To) commands + // eg "M29.383,38.5L29.383,63.5L29.383,83.2" => ["M29.383,38.5", "L29.383,63.5", "L29.383,83.2", "C29.383,83.2"] + const commands = dAttr.split(new RegExp(`(?=[${commandsPattern}])`)); // Get the start position from the first commands element => [29.383,38.5] const startPosition = commands[0] @@ -105,11 +106,17 @@ export const computeEdgePositions = ( // These includes the start and end points and also points which are not the same as the previous points const reflectionPoints = commands .map((command) => { + const commandType = command[0]; const coords = command .substring(1) .split(",") .map((coord) => parseFloat(coord)); - return { x: coords[0], y: coords[1] }; + + if (commandType === "C") { + return { x: coords[4], y: coords[5], command: commandType }; + } + + return { x: coords[0], y: coords[1], command: commandType }; }) .filter((point, index, array) => { // Always include the last point @@ -122,6 +129,11 @@ export const computeEdgePositions = ( return false; } + // Exclude the second last point if it's a "C" command because this is a curve and the last point is the end point, so we don't need to include it. + if (index === array.length - 2 && point.command === "C") { + return false; + } + // The below check is exclusively for second last point if ( index === array.length - 2 && @@ -161,3 +173,50 @@ export const computeEdgePositions = ( reflectionPoints, }; }; + +// Compute element position +export const computeElementPosition = ( + el: Element | null, + containerEl: Element +): Position => { + if (!el) { + throw new Error("Element not found"); + } + + let root = el.parentElement?.parentElement; + + const childElement = el.childNodes[0] as SVGSVGElement; + let childPosition = { x: 0, y: 0 }; + if (childElement) { + const { transformX, transformY } = getTransformAttr(childElement); + + const boundingBox = childElement.getBBox(); + childPosition = { + x: + Number(childElement.getAttribute("x")) || + transformX + boundingBox.x || + 0, + y: + Number(childElement.getAttribute("y")) || + transformY + boundingBox.y || + 0, + }; + } + + const { transformX, transformY } = getTransformAttr(el); + const position = { + x: transformX + childPosition.x, + y: transformY + childPosition.y, + }; + while (root && root.id !== containerEl.id) { + if (root.classList.value === "root" && root.hasAttribute("transform")) { + const { transformX, transformY } = getTransformAttr(root); + position.x += transformX; + position.y += transformY; + } + + root = root.parentElement; + } + + return position; +};