001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.Arrays;
016import java.util.List;
017
018/**
019 * Class to run over a pair of datasets in parallel with NumPy broadcasting to promote shapes
020 * which have lower rank and outputs to a third dataset
021 */
022public class BroadcastPairIterator extends BroadcastIterator {
023        private int[] aShape;
024        private int[] bShape;
025        private int[] aStride;
026        private int[] bStride;
027        private int[] oStride;
028
029        final private int endrank;
030
031        private final int[] aDelta, bDelta;
032        private final int[] oDelta; // this being non-null means output is different from inputs
033        private final int aStep, bStep, oStep;
034        private int aMax, bMax;
035        private int aStart, bStart, oStart;
036
037        /**
038         * 
039         * @param a
040         * @param b
041         * @param o (can be null for new dataset, a or b)
042         * @param createIfNull
043         */
044        public BroadcastPairIterator(Dataset a, Dataset b, Dataset o, boolean createIfNull) {
045                super(a, b, o);
046                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), b.getShapeRef(), o == null ? null : o.getShapeRef());
047
048                maxShape = fullShapes.remove(0);
049
050                oStride = null;
051                if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) {
052                        throw new IllegalArgumentException("Output does not match broadcasted shape");
053                }
054                aShape = fullShapes.remove(0);
055                bShape = fullShapes.remove(0);
056
057                int rank = maxShape.length;
058                endrank = rank - 1;
059
060                aDataset = a.reshape(aShape);
061                bDataset = b.reshape(bShape);
062                aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape);
063                bStride = BroadcastUtils.createBroadcastStrides(bDataset, maxShape);
064                if (outputA) {
065                        oStride = aStride;
066                        oDelta = null;
067                        oStep = 0;
068                } else if (outputB) {
069                        oStride = bStride;
070                        oDelta = null;
071                        oStep = 0;
072                } else if (o != null) {
073                        oStride = BroadcastUtils.createBroadcastStrides(o, maxShape);
074                        oDelta = new int[rank];
075                        oStep = o.getElementsPerItem();
076                } else if (createIfNull) {
077                        oDataset = BroadcastUtils.createDataset(a, b, maxShape);
078                        oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape);
079                        oDelta = new int[rank];
080                        oStep = oDataset.getElementsPerItem();
081                } else {
082                        oDelta = null;
083                        oStep = 0;
084                }
085
086                pos = new int[rank];
087                aDelta = new int[rank];
088                aStep = aDataset.getElementsPerItem();
089                bDelta = new int[rank];
090                bStep = bDataset.getElementsPerItem();
091                for (int j = endrank; j >= 0; j--) {
092                        aDelta[j] = aStride[j] * aShape[j];
093                        bDelta[j] = bStride[j] * bShape[j];
094                        if (oDelta != null) {
095                                oDelta[j] = oStride[j] * maxShape[j];
096                        }
097                }
098                if (endrank < 0) {
099                        aMax = aStep;
100                        bMax = bStep;
101                } else {
102                        aMax = Integer.MIN_VALUE; // use max delta
103                        bMax = Integer.MIN_VALUE;
104                        for (int j = endrank; j >= 0; j--) {
105                                if (aDelta[j] > aMax) {
106                                        aMax = aDelta[j];
107                                }
108                                if (bDelta[j] > bMax) {
109                                        bMax = bDelta[j];
110                                }
111                        }
112                }
113                aStart = aDataset.getOffset();
114                aMax += aStart;
115                bStart = bDataset.getOffset();
116                bMax += bStart;
117                oStart = oDelta == null ? 0 : oDataset.getOffset();
118                reset();
119        }
120
121        @Override
122        public boolean hasNext() {
123                int j = endrank;
124                int oldA = aIndex;
125                int oldB = bIndex;
126                for (; j >= 0; j--) {
127                        pos[j]++;
128                        aIndex += aStride[j];
129                        bIndex += bStride[j];
130                        if (oDelta != null)
131                                oIndex += oStride[j];
132                        if (pos[j] >= maxShape[j]) {
133                                pos[j] = 0;
134                                aIndex -= aDelta[j]; // reset these dimensions
135                                bIndex -= bDelta[j];
136                                if (oDelta != null)
137                                        oIndex -= oDelta[j];
138                        } else {
139                                break;
140                        }
141                }
142                if (j == -1) {
143                        if (endrank >= 0) {
144                                aIndex = aMax;
145                                bIndex = bMax;
146                                return false;
147                        }
148                        aIndex += aStep;
149                        bIndex += bStep;
150                        if (oDelta != null)
151                                oIndex += oStep;
152                }
153                if (outputA) {
154                        oIndex = aIndex;
155                } else if (outputB) {
156                        oIndex = bIndex;
157                }
158
159                if (aIndex == aMax || bIndex == bMax)
160                        return false;
161
162                if (read) {
163                        if (oldA != aIndex) {
164                                if (asDouble) {
165                                        aDouble = aDataset.getElementDoubleAbs(aIndex);
166                                } else {
167                                        aLong = aDataset.getElementLongAbs(aIndex);
168                                }
169                        }
170                        if (oldB != bIndex) {
171                                if (asDouble) {
172                                        bDouble = bDataset.getElementDoubleAbs(bIndex);
173                                } else {
174                                        bLong = bDataset.getElementLongAbs(bIndex);
175                                }
176                        }
177                }
178
179                return true;
180        }
181
182        /**
183         * @return shape of first broadcasted dataset
184         */
185        public int[] getFirstShape() {
186                return aShape;
187        }
188
189        /**
190         * @return shape of second broadcasted dataset
191         */
192        public int[] getSecondShape() {
193                return bShape;
194        }
195
196        @Override
197        public void reset() {
198                for (int i = 0; i <= endrank; i++)
199                        pos[i] = 0;
200
201                if (endrank >= 0) {
202                        pos[endrank] = -1;
203                        aIndex = aStart - aStride[endrank];
204                        bIndex = bStart - bStride[endrank];
205                        oIndex = oStart - (oStride == null ? 0 : oStride[endrank]);
206                } else {
207                        aIndex = aStart - aStep;
208                        bIndex = bStart - bStep;
209                        oIndex = oStart - oStep;
210                }
211
212                if (aIndex == 0 || bIndex == 0) { // for zero-ranked datasets
213                        if (read) {
214                                storeCurrentValues();
215                        }
216                        if (aMax == aIndex)
217                                aMax++;
218                        if (bMax == bIndex)
219                                bMax++;
220                }
221        }
222}