diff --git a/src/ArcherContainer/ArcherContainer.helpers.tsx b/src/ArcherContainer/ArcherContainer.helpers.tsx
index a32c57c..cf1efb5 100644
--- a/src/ArcherContainer/ArcherContainer.helpers.tsx
+++ b/src/ArcherContainer/ArcherContainer.helpers.tsx
@@ -1,24 +1,32 @@
import {
- ValidShapeTypes,
+ EntityRelationType,
LineType,
- SourceToTargetType,
ShapeType,
- EntityRelationType,
+ SourceToTargetType,
+ ValidShapeTypes,
} from '../types';
import { SourceToTargetsArrayType } from './ArcherContainer.types';
const possibleShapes: Array = ['arrow', 'circle'];
-export const getEndShapeFromStyle = (shapeObj: LineType) => {
- if (!shapeObj.endShape) {
- return possibleShapes[0];
+export const getEndShapeFromStyle = (shapeObj: LineType, shape: ShapeType) => {
+ if (shapeObj.endShape) {
+ const validShape = (Object.keys(shapeObj.endShape) as ValidShapeTypes[]).find((key) => {
+ return possibleShapes.includes(key);
+ });
+
+ if (validShape) {
+ return validShape;
+ }
}
- return (
- (Object.keys(shapeObj.endShape) as ValidShapeTypes[]).filter((key) =>
- possibleShapes.includes(key),
- )[0] || possibleShapes[0]
- );
+ if (shape.arrow) {
+ return 'arrow';
+ } else if (shape.circle) {
+ return 'circle';
+ }
+
+ return possibleShapes[0];
};
export const getSourceToTargets = (
@@ -35,7 +43,7 @@ export const getSourceToTargets = (
};
export const createShapeObj = (style: LineType, endShape: ShapeType) => {
- const chosenEndShape = getEndShapeFromStyle(style);
+ const chosenEndShape = getEndShapeFromStyle(style, endShape);
const shapeObjMap = {
arrow: () => ({
arrow: {
diff --git a/src/ArcherContainer/__tests__/ArcherContainer.test.tsx b/src/ArcherContainer/__tests__/ArcherContainer.test.tsx
index 65bfb9b..5160190 100644
--- a/src/ArcherContainer/__tests__/ArcherContainer.test.tsx
+++ b/src/ArcherContainer/__tests__/ArcherContainer.test.tsx
@@ -1,8 +1,8 @@
-import React from 'react';
import { render } from '@testing-library/react';
-import ArcherContainer from '../ArcherContainer';
-import ArcherElement from '../../ArcherElement/ArcherElement';
+import React from 'react';
import { act } from 'react-dom/test-utils';
+import ArcherElement from '../../ArcherElement/ArcherElement';
+import ArcherContainer from '../ArcherContainer';
const originalConsoleWarn = console.warn;
@@ -89,6 +89,33 @@ describe('ArcherContainer', () => {
expect(screen.baseElement).toMatchSnapshot();
});
+ it('should render the arrow with a circle end when provided in the container', () => {
+ const screen = render(
+
+
+ element 1
+
+
+ element 2
+
+ ,
+ );
+ expect(screen.baseElement).toMatchSnapshot();
+ });
+
it('should render the arrow on both ends', () => {
const screen = render(
diff --git a/src/ArcherContainer/__tests__/__snapshots__/ArcherContainer.test.tsx.snap b/src/ArcherContainer/__tests__/__snapshots__/ArcherContainer.test.tsx.snap
index 1b6ee36..9792333 100644
--- a/src/ArcherContainer/__tests__/__snapshots__/ArcherContainer.test.tsx.snap
+++ b/src/ArcherContainer/__tests__/__snapshots__/ArcherContainer.test.tsx.snap
@@ -289,6 +289,59 @@ exports[`ArcherContainer rendering an svg with the marker element used to draw a
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ element 1
+
+
+ element 2
+
+
+
+
+
diff --git a/src/ArcherContainer/components/Markers.tsx b/src/ArcherContainer/components/Markers.tsx
index eca8f71..a9af1d9 100644
--- a/src/ArcherContainer/components/Markers.tsx
+++ b/src/ArcherContainer/components/Markers.tsx
@@ -1,7 +1,7 @@
import React from 'react';
import { LineType, ShapeType, SourceToTargetType } from '../../types';
import { endShapeDefaultProp } from '../ArcherContainer.constants';
-import { getEndShapeFromStyle, getSourceToTargets, getMarkerId } from '../ArcherContainer.helpers';
+import { getEndShapeFromStyle, getMarkerId, getSourceToTargets } from '../ArcherContainer.helpers';
import { SourceToTargetsArrayType } from '../ArcherContainer.types';
const circleMarker = (style: LineType, endShape: ShapeType) => () => {
@@ -79,7 +79,7 @@ const buildShape = ({
refX: number;
refY: number;
} => {
- const chosenEndShape = getEndShapeFromStyle(style);
+ const chosenEndShape = getEndShapeFromStyle(style, endShape);
const shapeMap = {
circle: circleMarker(style, endShape),