diff --git a/src/ArcherContainer/ArcherContainer.helpers.tsx b/src/ArcherContainer/ArcherContainer.helpers.tsx index a32c57c..cf1efb5 100644 --- a/src/ArcherContainer/ArcherContainer.helpers.tsx +++ b/src/ArcherContainer/ArcherContainer.helpers.tsx @@ -1,24 +1,32 @@ import { - ValidShapeTypes, + EntityRelationType, LineType, - SourceToTargetType, ShapeType, - EntityRelationType, + SourceToTargetType, + ValidShapeTypes, } from '../types'; import { SourceToTargetsArrayType } from './ArcherContainer.types'; const possibleShapes: Array = ['arrow', 'circle']; -export const getEndShapeFromStyle = (shapeObj: LineType) => { - if (!shapeObj.endShape) { - return possibleShapes[0]; +export const getEndShapeFromStyle = (shapeObj: LineType, shape: ShapeType) => { + if (shapeObj.endShape) { + const validShape = (Object.keys(shapeObj.endShape) as ValidShapeTypes[]).find((key) => { + return possibleShapes.includes(key); + }); + + if (validShape) { + return validShape; + } } - return ( - (Object.keys(shapeObj.endShape) as ValidShapeTypes[]).filter((key) => - possibleShapes.includes(key), - )[0] || possibleShapes[0] - ); + if (shape.arrow) { + return 'arrow'; + } else if (shape.circle) { + return 'circle'; + } + + return possibleShapes[0]; }; export const getSourceToTargets = ( @@ -35,7 +43,7 @@ export const getSourceToTargets = ( }; export const createShapeObj = (style: LineType, endShape: ShapeType) => { - const chosenEndShape = getEndShapeFromStyle(style); + const chosenEndShape = getEndShapeFromStyle(style, endShape); const shapeObjMap = { arrow: () => ({ arrow: { diff --git a/src/ArcherContainer/__tests__/ArcherContainer.test.tsx b/src/ArcherContainer/__tests__/ArcherContainer.test.tsx index 65bfb9b..5160190 100644 --- a/src/ArcherContainer/__tests__/ArcherContainer.test.tsx +++ b/src/ArcherContainer/__tests__/ArcherContainer.test.tsx @@ -1,8 +1,8 @@ -import React from 'react'; import { render } from '@testing-library/react'; -import ArcherContainer from '../ArcherContainer'; -import ArcherElement from '../../ArcherElement/ArcherElement'; +import React from 'react'; import { act } from 'react-dom/test-utils'; +import ArcherElement from '../../ArcherElement/ArcherElement'; +import ArcherContainer from '../ArcherContainer'; const originalConsoleWarn = console.warn; @@ -89,6 +89,33 @@ describe('ArcherContainer', () => { expect(screen.baseElement).toMatchSnapshot(); }); + it('should render the arrow with a circle end when provided in the container', () => { + const screen = render( + + +
element 1
+
+ +
element 2
+
+
, + ); + expect(screen.baseElement).toMatchSnapshot(); + }); + it('should render the arrow on both ends', () => { const screen = render( diff --git a/src/ArcherContainer/__tests__/__snapshots__/ArcherContainer.test.tsx.snap b/src/ArcherContainer/__tests__/__snapshots__/ArcherContainer.test.tsx.snap index 1b6ee36..9792333 100644 --- a/src/ArcherContainer/__tests__/__snapshots__/ArcherContainer.test.tsx.snap +++ b/src/ArcherContainer/__tests__/__snapshots__/ArcherContainer.test.tsx.snap @@ -289,6 +289,59 @@ exports[`ArcherContainer rendering an svg with the marker element used to draw a `; + +exports[`ArcherContainer rendering an svg with the marker element used to draw an svg arrow should render the arrow with a circle end when provided in the container 1`] = ` + +
+
+ + + + + + + + + + +
+
+ element 1 +
+
+ element 2 +
+
+
+
+ +`; + exports[`ArcherContainer rendering an svg with the marker element used to draw an svg arrow should render the arrow with an arrow end by default 1`] = `
diff --git a/src/ArcherContainer/components/Markers.tsx b/src/ArcherContainer/components/Markers.tsx index eca8f71..a9af1d9 100644 --- a/src/ArcherContainer/components/Markers.tsx +++ b/src/ArcherContainer/components/Markers.tsx @@ -1,7 +1,7 @@ import React from 'react'; import { LineType, ShapeType, SourceToTargetType } from '../../types'; import { endShapeDefaultProp } from '../ArcherContainer.constants'; -import { getEndShapeFromStyle, getSourceToTargets, getMarkerId } from '../ArcherContainer.helpers'; +import { getEndShapeFromStyle, getMarkerId, getSourceToTargets } from '../ArcherContainer.helpers'; import { SourceToTargetsArrayType } from '../ArcherContainer.types'; const circleMarker = (style: LineType, endShape: ShapeType) => () => { @@ -79,7 +79,7 @@ const buildShape = ({ refX: number; refY: number; } => { - const chosenEndShape = getEndShapeFromStyle(style); + const chosenEndShape = getEndShapeFromStyle(style, endShape); const shapeMap = { circle: circleMarker(style, endShape),