001/*-
002 * Copyright 2016 Diamond Light Source Ltd.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 */
009
010package org.eclipse.january.dataset;
011
012import java.util.List;
013
014/**
015 * Class to run over a pair of datasets in parallel with NumPy broadcasting of second dataset
016 */
017public class BroadcastSingleIterator extends BroadcastSelfIterator {
018        private int[] bShape;
019        private int[] aStride;
020        private int[] bStride;
021
022        final private int endrank;
023
024        private final int[] aDelta, bDelta;
025        private final int aStep, bStep;
026        private int aMax, bMax;
027        private int aStart, bStart;
028
029        /**
030         * 
031         * @param a
032         * @param b
033         */
034        public BroadcastSingleIterator(Dataset a, Dataset b) {
035                super(a, b);
036
037                int[] aShape = a.getShapeRef();
038                maxShape = aShape;
039                List<int[]> fullShapes = BroadcastUtils.broadcastShapesToMax(maxShape, b.getShapeRef());
040                bShape = fullShapes.remove(0);
041
042                int rank = maxShape.length;
043                endrank = rank - 1;
044
045                bDataset = b.reshape(bShape);
046                int[] aOffset = new int[1];
047                aStride = AbstractDataset.createStrides(aDataset, aOffset);
048                bStride = BroadcastUtils.createBroadcastStrides(bDataset, maxShape);
049
050                pos = new int[rank];
051                aDelta = new int[rank];
052                aStep = aDataset.getElementsPerItem();
053                bDelta = new int[rank];
054                bStep = bDataset.getElementsPerItem();
055                for (int j = endrank; j >= 0; j--) {
056                        aDelta[j] = aStride[j] * aShape[j];
057                        bDelta[j] = bStride[j] * bShape[j];
058                }
059                aStart = aOffset[0];
060                bStart = bDataset.getOffset();
061                aMax = endrank < 0 ? aStep + aStart: Integer.MIN_VALUE;
062                bMax = endrank < 0 ? bStep + bStart: Integer.MIN_VALUE;
063                reset();
064        }
065
066        @Override
067        public boolean hasNext() {
068                int j = endrank;
069                int oldB = bIndex;
070                for (; j >= 0; j--) {
071                        pos[j]++;
072                        aIndex += aStride[j];
073                        bIndex += bStride[j];
074                        if (pos[j] >= maxShape[j]) {
075                                pos[j] = 0;
076                                aIndex -= aDelta[j]; // reset these dimensions
077                                bIndex -= bDelta[j];
078                        } else {
079                                break;
080                        }
081                }
082                if (j == -1) {
083                        if (endrank >= 0) {
084                                return false;
085                        }
086                        aIndex += aStep;
087                        bIndex += bStep;
088                }
089
090                if (aIndex == aMax || bIndex == bMax) {
091                        return false;
092                }
093
094                if (read) {
095                        if (oldB != bIndex) {
096                                if (asDouble) {
097                                        bDouble = bDataset.getElementDoubleAbs(bIndex);
098                                } else {
099                                        bLong = bDataset.getElementLongAbs(bIndex);
100                                }
101                        }
102                }
103
104                return true;
105        }
106
107        /**
108         * @return shape of first broadcasted dataset
109         */
110        public int[] getFirstShape() {
111                return maxShape;
112        }
113
114        /**
115         * @return shape of second broadcasted dataset
116         */
117        public int[] getSecondShape() {
118                return bShape;
119        }
120
121        @Override
122        public void reset() {
123                for (int i = 0; i <= endrank; i++)
124                        pos[i] = 0;
125
126                if (endrank >= 0) {
127                        pos[endrank] = -1;
128                        aIndex = aStart - aStride[endrank];
129                        bIndex = bStart - bStride[endrank];
130                } else {
131                        aIndex = aStart - aStep;
132                        bIndex = bStart - bStep;
133                }
134
135                if (aIndex == 0 || bIndex == 0) { // for zero-ranked datasets
136                        if (read) {
137                                storeCurrentValues();
138                        }
139                }
140        }
141}